#System modules
import os, sys
import inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)

import importlib

#Personal files
sys.path.append(parentdir)
import fun_figs as ff
import functions_optosplit as opt
import functions_tiff_image_manipulation as ti
import functions_post_array_analysis as pa
import functions_wave_general as wg
importlib.reload(opt)
importlib.reload(ff)
importlib.reload(ti)
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

#General libraries

import matplotlib.pyplot as plt
from skimage import exposure
from matplotlib import patches
import numpy as np
import pandas as pd
import math
import time
import copy
import tifffile
import seaborn as sns
import random
from PIL import Image


def gen_video_df(v, df_vid):
    """
    Copy basic video metadata from ``v`` into the first row of ``df_vid``.

    v: object exposing the attributes file_name0, n_frames, time_s and shape
    df_vid: data frame subset for one video; the columns 'file_name0',
        'n_frames', 'time_s' and 'shape' are written with .loc on its first
        index label (shape is stored as its string representation)

    Returns the same df_vid object.
    """
    # Collect every attribute first so a missing one fails before any write
    metadata = {
        'file_name0': v.file_name0,
        'n_frames': v.n_frames,
        'time_s': v.time_s,
        'shape': str(v.shape),
    }
    for column, value in metadata.items():
        df_vid.loc[df_vid.index[0], column] = value
    return df_vid


def extract_crop_values_From_df(df, str='crop_area'):
    """
    Parse the four integer crop values stored as text in the first row of ``df``.

    df: data frame holding the crop specification
    str: name of the column to read (the name shadows the builtin; it is kept
        unchanged so existing keyword callers keep working)

    Returns a list of ints parsed from a comma-separated cell such as
    '10, 20, 30, 40' or '10,20,30,40'. Returns [0, 0, 0, 0] (i.e. no
    cropping) when the column is missing or when the cell is NaN/None, blank
    or 'none'.

    Raises ValueError if one of the comma-separated parts is not an integer.
    """
    column = str
    if column not in df.columns:
        return [0, 0, 0, 0]  # i.e. no cropping

    value = df[column].values[0]
    # Empty spreadsheet cells arrive as NaN (or None)
    if pd.isna(value):
        return [0, 0, 0, 0]

    text = value.strip()
    if text == '' or text == 'none':
        return [0, 0, 0, 0]

    # int() tolerates surrounding whitespace, so both ',' and ', ' work
    return [int(part) for part in text.split(',')]

import nd2_handling2
def read_file_array_pixelation(df_vid, d_vid, settings_general):
    """
    Read the image stack(s) of one video depending on the mode, then rotate
    and border-crop them.

    df_vid: data frame subset for the video; the first values of the columns
        'file_path', 'dir_exp' and 'angle' are used
    d_vid: per-video settings, forwarded to transform_img for the rotation
    settings_general: general settings; reads 'frame_range', 'mode',
        'crop_border_width' and the optional flags
        'only_process_I_tot_polarization' and 'save_as_uint16'

    Returns (imgs, v): the list of processed stacks and the
    nd2_handling2.Video_waves object (opened without image data) describing
    the video. In 'polarization' mode imgs is [I_perpen, I_parallel, I_tot],
    or only [I_tot] when 'only_process_I_tot_polarization' is truthy;
    otherwise it is a single-element list.
    """
    print('\n  - Reading files')
    file_path = df_vid.file_path.values[0]
    dir_exp = df_vid.dir_exp.values[0]
    v = nd2_handling2.Video_waves(file_path, read_img_directly=False, frame_range=settings_general['frame_range'], dir_exp=dir_exp)
    v = nd2_handling2.update_v_from_nd2_spreadsheet_if_pims_could_not_extract_metadata(v, df_vid, settings_general)

    # Fall back to the whole video when the requested range ends past the last frame
    if settings_general['frame_range'][-1] >= v.n_frames:
        v.frame_range = range(0, v.n_frames)
    else:
        v.frame_range = settings_general['frame_range']
    print(f'\tVideo file #{v.file_nbr}: \n\t  - frame range: {v.frame_range}, \n\t  - frame to display: {v.frame_to_display}, \n\t  - p = {v.p} mbar, \n\t  - frame rate = {v.frame_rate} fps, \n\t  - cropping area = {v.crop}\n')

    if settings_general['mode'] == 'polarization':
        I_perpen, I_parallel, I_tot = opt.load_made_polarization_files(v, settings_general, dir_exp)
        if settings_general.get('only_process_I_tot_polarization', False):
            imgs = [I_tot]
        else:
            imgs = [I_perpen, I_parallel, I_tot]
    else:
        tic = time.perf_counter()
        print(f"Reading file {v.file_name0}")
        # Original note: if the selected frames to read are all set to 0, the full
        # range is read in all dimensions, otherwise only the selected range.
        v_with_img = nd2_handling2.Video_waves(file_path, read_img_directly=True, frame_range=settings_general['frame_range'], dir_exp=dir_exp)
        img = v_with_img.img

        # Rescale to uint8 unless the caller explicitly asks to keep the raw depth.
        # The notice is only printed when the rescaling really is skipped.
        if settings_general.get('save_as_uint16', False):
            print('!!! NO RESCALING TO UINT8 !!!')
        else:
            img = ff.rescale_to_uint8(img)

        # Crop images according to cropping_area in the experiment spreadsheet
        print(f'333 {v.crop}')
        if not all(i == 0 for i in v.crop):
            img = ff.crop_image(img, v.crop)
        toc = time.perf_counter()

        print(f'\tRead {img.shape[0]} frames from {v.file_path} in {toc - tic:0.2f} seconds')
        imgs = [img]

    print('!!!1 raw img:', imgs[0].shape)

    # Rotate and crop images
    angle_transform = df_vid.angle.values[0]
    rotate = angle_transform != 0 and angle_transform != ' '
    if rotate:
        print('Rotating the raw image stack')
    for i, img in enumerate(imgs):
        if rotate:
            imgs[i] = transform_img(img, d_vid)
        if i == 0:
            print('!!!2 img after rotation but before cropping:', imgs[0].shape)
        # Crop the border with specified border width from the general settings
        imgs[i] = ff.crop_border(imgs[i], b=settings_general['crop_border_width'])
    print('!!!2 img after rotation and cropping:', imgs[0].shape)
    return imgs, v

def transform_img(img, d_vid, fillcolor=None):
    """
    Rotate/flip an image according to the per-video settings.

    img: image handed to ff.transform_img (original note: np.array, uint8)
    d_vid: per-video settings with the key 'angle' and, optionally,
        'rot_flip' (the transform mode, 'rot' when the key is absent)
    fillcolor: forwarded unchanged to ff.transform_img

    Returns img untouched when the angle is 0, ' ' or any other string,
    otherwise the result of ff.transform_img.

    Original author's notes: bilinear rotation, angle in degrees,
    counter-clockwise. Quite slow, as the rotation is done frame by frame in
    a for loop using PIL. There might be a faster way with OpenCV, but no way
    was found to perform bilinear or bicubic interpolation with it.
    """
    mode = d_vid['rot_flip'] if 'rot_flip' in d_vid else 'rot'
    angle = d_vid['angle']

    # Zero, blank or textual angles mean there is nothing to rotate
    nothing_to_do = angle == 0 or angle == ' ' or isinstance(angle, str)
    if nothing_to_do:
        return img
    return ff.transform_img(img, angle=angle, transform_mode=mode, fillcolor=fillcolor)
def find_out_what_MED_file_to_use(settings_general, v):
    """
    Build the path of the median-image file for the video ``v``.

    settings_general: general settings; the optional key 'MED_file' selects
        the 'parallel' or 'perpen' variant, any other value means 'tot'
    v: object exposing n_frames, file_name0 and dir_path_file_folder

    Returns '<dir_path_file_folder>/MED_<n_frames>_<file_name0>.tif' for the
    'tot' variant and '.../MED_<variant>_<n_frames>_<file_name0>.tif'
    otherwise.
    """
    base_name = '_'.join([str(v.n_frames), v.file_name0]) + '.tif'

    variant = 'tot'
    if 'MED_file' in settings_general:
        requested = settings_general['MED_file']
        for candidate in ('parallel', 'perpen'):
            if requested == candidate:
                variant = candidate
                break
        print(f'\t\tBasing the median file on {variant} image instead of the mean image of the two sides')

    prefix = 'MED_' if variant == 'tot' else 'MED_' + variant + '_'
    return os.path.join(v.dir_path_file_folder, prefix + base_name)
def rotate_img_and_show_alignment(img_med, d_vid, settings_general, angle_transform=0):
    """
    Rotate the median image when an angle is given and optionally plot it
    with red guide lines to judge the rotational alignment.

    img_med: median image
    d_vid: per-video settings forwarded to transform_img
    settings_general: general settings; reads 'show_imgs' and the optional
        'show_rotational_alignment' (treated as False when absent)
    angle_transform: rotation angle; 0 or ' ' means no rotation

    Returns the (possibly rotated) median image.
    """
    needs_rotation = angle_transform != 0 and angle_transform != ' '
    if needs_rotation:
        print('Rotating the median image')
        img_med = transform_img(img_med, d_vid)
    else:
        # Normalise the blank marker so the plot title reads "0 degrees"
        angle_transform = 0

    if 'show_rotational_alignment' in settings_general:
        show_alignment = settings_general['show_rotational_alignment']
    else:
        show_alignment = False
    if not (settings_general['show_imgs'] and show_alignment):
        return img_med

    _, ax = plt.subplots(1, 1, figsize=(10, 10), sharex=True)
    plt.imshow(img_med, cmap='gray')
    plt.title(f'Median image, rotation = {angle_transform} degrees')
    h, w = img_med.shape
    # Guide lines at 10 %, 50 % and 90 % of the height and width
    for fraction in (0.1, 0.5, 0.9):
        plt.axhline(h * fraction, color='red', linewidth=1)
        plt.axvline(w * fraction, color='red', linewidth=1)
    ax.axis('off')
    plt.show()
    return img_med

def retrieve_med_file(settings_general, v, df_vid, d_vid):
    """
    Read the median image if it exists, otherwise create it and save it
    (based on the full number of frames). The result is rescaled to uint8
    with ff.rescale_to_uint8 and rotated via rotate_img_and_show_alignment.

    settings_general: general settings; reads 'mode' and, in 'polarization'
        mode, 'analyze_based_on_mean_value_of_the_two_sides_for_polarization'
        and the optional 'MED_file' ('parallel' / 'perpen', otherwise the
        total-intensity stack is used)
    v: video object (n_frames, file_name0, file_path, crop,
        dir_path_file_folder are read)
    df_vid: data frame subset of the video ('dir_exp' and 'angle' are read)
    d_vid: per-video settings forwarded to the rotation

    Returns the median image.

    Raises ValueError in 'polarization' mode when
    'analyze_based_on_mean_value_of_the_two_sides_for_polarization' is falsy;
    analysing the two optosplit sides separately is no longer supported.
    """
    n_frames = v.n_frames
    polarization = settings_general['mode'] == 'polarization'

    if polarization and not settings_general['analyze_based_on_mean_value_of_the_two_sides_for_polarization']:
        raise ValueError('This option is deleted. see old file if you want to restore it')

    if polarization:
        print('\t- Mode: Polarization - Taking the median of the total intensity image.')

        path_med = find_out_what_MED_file_to_use(settings_general, v)
        if os.path.exists(path_med):
            img_med = ti.read_tif_file(path_med)
        else:
            # Load the polarization stacks with frame_range=[0,0]
            # (original note: full stacks, so this will be slow)
            dir_exp = df_vid.dir_exp.values[0]
            I_perpen, I_parallel, I_tot = opt.load_made_polarization_files(v, settings_general, dir_exp, frame_range = [0,0])

            # Pick the stack the median is based on
            med_source = settings_general.get('MED_file')
            if med_source == 'parallel':
                stack = I_parallel
            elif med_source == 'perpen':
                stack = I_perpen
            else:
                stack = I_tot

            # Make median images
            img_med = create_med_file_small(stack, path_med, df_vid, v, read_stack_from_file=False)
            del I_perpen, I_parallel, I_tot, stack
    else:
        # Find the path for the median file
        file_name_full = str(v.n_frames) + '_' + v.file_name0 + '.tif'
        path_med = os.path.join(v.dir_path_file_folder, 'MED_' + file_name_full)

        if os.path.exists(path_med):
            img_med = ti.read_tif_file(path_med)
        else:
            print(f'\tCreating a median file based on {n_frames} frames')
            tic = time.perf_counter()
            v_with_img = nd2_handling2.Video_waves(v.file_path, read_img_directly=True)
            img_full = v_with_img.img

            # Rescale to uint8
            img_full = ff.rescale_to_uint8(img_full)

            # Crop images according to cropping_area in the experiment spreadsheet
            if not all(i == 0 for i in v.crop):
                img_full = ff.crop_image(img_full, v.crop)

            # Calculate the median along the Z-axis
            img_med = np.median(img_full, axis=0)
            del img_full
            ti.write_tif_file(path_med, img_med, photometric='minisblack')
            toc = time.perf_counter()
            print(f'\tCreated a median image ({img_med.dtype}) based on {n_frames} frames in {toc - tic:0.2f} seconds\n')

    # Rescale to uint8
    img_med = ff.rescale_to_uint8(img_med)

    print('!!!3 median img before rotation:', img_med.shape)
    img_med = rotate_img_and_show_alignment(img_med, d_vid, settings_general, angle_transform = df_vid.angle.values[0])
    print('!!!3 median img after rotation:', img_med.shape)
    return img_med

def create_med_file_small(img, path_med, df_vid, v, read_stack_from_file=True, path_file = ''):
    """
    Either loads the median file or creates it.

    img: 3D image stack (original note: uint8 numpy array); only used when
        read_stack_from_file is False
    path_med: path of the median file to read, or to write if it is missing
    df_vid, v: not used by this function; kept for interface compatibility
    read_stack_from_file: when True the stack is read from path_file with
        ti.read_tif_file and rescaled with ff.rescale_to_uint8, instead of
        using img
    path_file: path of the stack to read when read_stack_from_file is True

    Returns the median image cast to uint8. The file written to path_med
    holds the np.median result before that cast.
    """
    # Read the median file if it exists, otherwise create it and save it
    if os.path.exists(path_med):
        img_med = tifffile.imread(path_med)
        print('Found median file, reading it: '+path_med+'\n')
    else:
        print('Did not find a median file, creating it.')
        tic = time.perf_counter()
        if read_stack_from_file:
            img_full = ff.rescale_to_uint8(ti.read_tif_file(path_file))
        else:
            img_full = img
        # Count the frames of the stack that is actually used, so that img is
        # not touched when the stack comes from path_file
        n_frames = img_full.shape[0]
        print(f'\tBasing the new median file on {n_frames} frames.')

        # Calculate the median along the Z-axis
        img_med = np.median(img_full, axis=0)
        del img_full  # to save memory

        # Write to file
        ti.write_tif_file(path_med, img_med, photometric='minisblack')
        toc = time.perf_counter()
        print(f'\tCreated a median image ({img_med.dtype}) based on {n_frames} frames in {toc - tic:0.2f} seconds\n')
    return img_med.astype(np.uint8)

def extract_parameters_for_mask_generation(df_vid, settings_device_type):
    """
    Collect the mask-generation parameters, letting the per-video spreadsheet
    values override the device-type defaults.

    df_vid: data frame subset of the video; when a column named like a
        parameter exists, its first value is used as the override
    settings_device_type: mapping with the defaults 'area_min_factor',
        'factor_local_otsu', 'otsu_radius_mask_gen' and
        'factor_axis_major_length' (all four keys are required)

    An override is applied only when it converts to a float that is not NaN.
    Blank cells and NaN keep the default silently; unconvertible values keep
    the default and print a notice. All four parameters are handled
    identically, so integer and numeric-text cells are accepted for
    'area_min_factor' as well.

    Returns the tuple (area_min_factor, factor_local_otsu,
    otsu_radius_mask_gen, factor_axis_major_length).
    """
    names = ('area_min_factor', 'factor_local_otsu', 'otsu_radius_mask_gen', 'factor_axis_major_length')

    # General list of parameters; read first so a missing default raises KeyError
    params = {name: settings_device_type[name] for name in names}

    # The local list of parameters overrides the general list if present
    for name in names:
        if name not in df_vid.columns:
            continue
        raw = df_vid[name].values[0]
        if isinstance(raw, str) and not raw.strip():
            continue  # blank cell: keep the default
        try:
            value = float(raw)
        except (TypeError, ValueError, OverflowError):
            print(f'Error (ignored) it did not work to convert {name} (={raw}, {type(raw)}) to float')
            continue
        if math.isnan(value):
            continue  # empty numeric cell: keep the default
        print(f'\t\tLoading from nd2-spreadsheet: {name}: ', value, ', ', type(value))
        params[name] = value

    return tuple(params[name] for name in names)
import scipy.ndimage as ndimage
def gen_masks(img_med, v, df_vid, d_vid, settings_general, settings_device_type):
    """
    Generate (or load) the masks of the non-fluid pixels (posts and side walls).

    Original outline of the generation step:
        1. Threshold the median image.
        2. Remove too small objects
        3. If large objects exist, remove the side walls for videos where the
           wall is included to generate the post mask; otherwise the bg mask
           and the post mask are identical.

    The masks are generated when either mask file is missing in
    v.dir_pixelated_files or when settings_general['gen_masks_if_already_exists']
    is truthy; otherwise both are read from file.

    img_med: median image, must be uint8
    v: video object (dir_pixelated_files and file_name0 are read)
    df_vid: data frame subset of the video (per-video parameter overrides)
    d_vid: per-video settings
    settings_general: general settings (display/saving flags,
        'use_convolved_post_mask', ...)
    settings_device_type: device-type defaults

    Returns (mask_fluid, mask_posts, img_med_cropped, mask_conv_3).
    img_med_cropped is None when the masks were loaded from file.
    mask_conv_3 is None when no post-center image is available, i.e. when the
    masks were loaded from file or 'use_convolved_post_mask' is falsy. These
    cases previously failed with a NameError.

    Raises ValueError if img_med is not uint8. Calls sys.exit when
    settings_general['perform_only_up_to_after_post_array_mask_gen'] is truthy.
    """
    print('\tGenerating masks')
    if not img_med.dtype == 'uint8':
        raise ValueError(f'The median image data type is {img_med.dtype} and not uint8.')

    # Set paths
    path_mask_fluid = os.path.join(v.dir_pixelated_files, 'mask_fluid_' + v.file_name0 + '.tif')
    path_mask_posts = os.path.join(v.dir_pixelated_files, 'mask_posts_' + v.file_name0 + '.tif')

    # Only available when the masks are generated in this call
    img_med_cropped = None
    bw_post_centers = None

    # If the mask files do not exist, create them.
    # This selection is overridden by the parameter in the general settings dict.
    masks_on_disk = os.path.exists(path_mask_fluid) and os.path.exists(path_mask_posts)
    if not masks_on_disk or settings_general['gen_masks_if_already_exists']:

        area_min_factor, factor_local_otsu, otsu_radius_mask_gen, factor_axis_major_length = extract_parameters_for_mask_generation(df_vid, settings_device_type)

        disp_details = False
        if 'disp_details_preprocessing' in settings_general:
            if settings_general['show_imgs'] and settings_general['disp_details_preprocessing']:
                disp_details = True

        # Create the post array mask
        img_med_cropped, mask_posts, mask_fluid, img_list, titles_list = pa.create_post_array_mask2(
            img_med,
            crop_border_width=1,
            area_min_factor=area_min_factor,
            otsu_radius = otsu_radius_mask_gen,
            factor_axis_major_length = factor_axis_major_length,
            factor_local_otsu = factor_local_otsu,
            dilate=False,
            disp_details = disp_details)

        # Show every intermediate image of the mask generation. (The former
        # ff.plot_sub_plots alternative could never be reached and was dropped.)
        if settings_general['show_imgs'] and settings_general['show_imgs_preprocessing'] and settings_general['show_imgs_preprocessing_part1_maskgen']:
            for i, I in enumerate(img_list):
                ff.show_img(I, title = str((i+1))+', '+titles_list[i], cmap='gray')

        if 'perform_only_up_to_after_post_array_mask_gen' in settings_general:
            if settings_general['perform_only_up_to_after_post_array_mask_gen']:
                sys.exit('perform_only_up_to_after_post_array_mask_gen')

        # Update masks based on finding the weighted centroid and convolving them with a circle
        if settings_general['use_convolved_post_mask']:
            mask_posts, mask_fluid, _, _, bw_post_centers = make_post_mask_based_on_circle_convolution(img_med_cropped, d_vid, mask_posts, settings_device_type,settings_general, disp_details = True)
        else:
            if settings_general['show_imgs'] and settings_general['show_imgs_preprocessing']:
                out = ff.mask_gray_img(img_med_cropped, mask_posts, color=[255, 0, 0], alpha=0.5)
                ff.show_img(out, sz=20, title_text='Post mask (unconvolved) overlaid on the median image')

        if settings_general['show_imgs'] and settings_general['show_imgs_preprocessing'] and settings_general['show_imgs_preprocessing_part2']:
            ff.show_img(mask_fluid, sz=10, title_text='Mask of the fluid')
            out = ff.mask_gray_img(img_med_cropped, mask_fluid, color=[255, 0, 0], alpha=0.2)
            if 'show_img_zoom_window' in settings_general:
                show_img_zoom_window = settings_general['show_img_zoom_window']
            else:
                show_img_zoom_window = [[150,200],[200,250]]
            ff.show_img(out, sz=10, title_text='Mask of the fluid (zoomed-in)', xlim=show_img_zoom_window[0], ylim=show_img_zoom_window[1])

        # Save to file
        if settings_general['save_imgs'] or settings_general['gen_masks_if_already_exists']:
            ti.write_tif_file(path_mask_fluid, mask_fluid)
            ti.write_tif_file(path_mask_posts, mask_posts)
    else:
        print('\tLoaded fluid mask')
        mask_fluid  = ti.read_tif_file(path_mask_fluid)
        mask_posts = ti.read_tif_file(path_mask_posts)

    if bw_post_centers is None:
        # Without the post centers there is nothing to convolve
        print('\tNo post-center image available; mask_conv_3 is set to None')
        mask_conv_3 = None
    else:
        conv_circle_radius = int(get_var_from_dic(d_vid, settings_device_type, 'conv_circle_radius'))
        circle = create_circle(radius=conv_circle_radius)
        # Convolve the post centroid matrix with the circle object
        mask_conv_3 = signal.convolve2d(bw_post_centers, circle, mode='same')

    return mask_fluid, mask_posts, img_med_cropped, mask_conv_3

from skimage import measure
from scipy import signal

def create_mask_from_centroid_list(img, list_centroids):
    """
    Create a boolean image with one True pixel per centroid.

    img: array whose shape the mask takes over
    list_centroids: sequence of (x, y) points, e.g. an (N, 2) numpy array or
        a list of pairs; each coordinate is rounded with round() and the
        pixel [y, x] is set

    Returns a boolean array shaped like img. An empty list_centroids gives an
    all-False mask.
    """
    bw = np.zeros_like(img, dtype=bool)

    # Mark the rounded position of every centroid (x indexes columns, y rows)
    for point in list_centroids:
        x = point[0]
        y = point[1]
        bw[round(y), round(x)] = True
    return bw

def find_post_objects(mask_posts, img_inv, settings_general, disp_details=True):
    """
    Label the objects of the post mask and return their region properties,
    splitting objects that look like several merged posts.

    mask_posts: 2D mask of the posts, labelled with skimage.measure.label
    img_inv: intensity image used for the weighted centroids
    settings_general: general settings; the 'show_imgs*' flags control the plots
    disp_details: print the object counts when True

    Objects with an area above 1.5 x the median area are treated as merged
    and re-segmented with pa.watershed_segment_image; objects below
    0.15 x the median area are discarded. If the median area exceeds 80
    pixels, 53 pixels is used instead.

    Returns a data frame with the columns produced by regionprops_table for
    'area', 'coords', 'major_axis_length' and 'weighted_centroid' (row, col).

    Raises ValueError if no object is found.
    """
    print('Finding objects')
    region_properties = ('area', 'coords', 'major_axis_length', 'weighted_centroid')

    # Extract the region properties out of the objects
    labels = measure.label(mask_posts)
    props_table = measure.regionprops_table(labels, intensity_image=img_inv, properties=region_properties)

    # Convert the list of the region properties to a pandas data frame
    df_props = pd.DataFrame(props_table)
    if len(df_props) == 0:
        raise ValueError('Could not find any objects in the image')
    areas = df_props.area.values
    med_area = np.median(areas)
    if med_area > 80:
        print(f'The median post object area is too high ({med_area}), instead using a value of 53 pixels')
        med_area = 53
    factor_merged = 1.5
    lim_lower_area = factor_merged*med_area
    lim_min_area = 0.1*factor_merged*med_area
    print(f'\tArea thresholds:')
    print(f'\t lim_lower_area = {lim_lower_area:.1f} pixels')
    print(f'\t lim_min_area = {lim_min_area:.1f} pixels')

    if settings_general['show_imgs'] and settings_general['show_imgs_preprocessing'] and settings_general['show_imgs_preprocessing_part1_maskgen']:
        fig = plt.figure(figsize=(10, 3))
        fig.add_subplot(1, 1, 1)
        plt.title('Histogram over areas of identified post objects')
        plt.hist(areas, range=(0,200), bins=256)
        plt.axvline(lim_lower_area, color='r', label=f'merged objects defined as \nabove {round(lim_lower_area)} or {factor_merged} x median val.')
        plt.axvline(np.median(areas), linestyle='--', color='g', label=f'median = {np.round(np.median(areas),2)} pix')
        plt.axvline(lim_min_area, linestyle='--', color='b', label=f'Excluding objects with area lower than min = {round(lim_min_area,1)} pix')
        plt.xlabel('area [pix^2]')
        plt.ylabel('count')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.show()

    df_props_merged = df_props[df_props['area'] > lim_lower_area]
    df_props_single = df_props[(df_props['area'] <= lim_lower_area) & (df_props['area'] >= lim_min_area)]

    if disp_details:
        print(f'\tFound in total {len(df_props_merged)} merged objects and {len(df_props_single)} single objects')

    if len(df_props_merged) > 0:
        # Paint the merged objects into an empty image; each 'coords' entry
        # is an (N, 2) array of (row, col) pixel coordinates
        bw_merged_objs = np.zeros_like(mask_posts)
        for obj in df_props_merged['coords'].to_numpy():
            bw_merged_objs[obj[:, 0], obj[:, 1]] = 1

        if settings_general['show_imgs']:
            ff.show_img(bw_merged_objs, sz=5, title='Showing merged objects (in white)', cmap='gray')

        # Segment the merged posts and add them to a data frame
        labels_watershed = pa.watershed_segment_image(bw_merged_objs, disp_details = True, min_distance = 4)
        img_inv_merged = copy.deepcopy(img_inv)
        # "== 0" selects the background for boolean and for integer masks
        # alike, whereas "~" is only a logical NOT on boolean arrays
        img_inv_merged[bw_merged_objs == 0] = 0
        props_table_watershed = measure.regionprops_table(labels_watershed, intensity_image=img_inv_merged, properties=region_properties)
        df_props_split = pd.DataFrame(props_table_watershed)

        # Combine the single posts with the split posts
        # (DataFrame.append no longer exists in pandas 2.x)
        df_props_fin = pd.concat([df_props_single, df_props_split], ignore_index=True)
    else:
        df_props_fin = df_props_single

    if disp_details:
        print(f'\tFinal data frame contains {len(df_props_fin)} objects')
    return df_props_fin

def transform_list_centroids_to_1D(list_centroids):
    """
    Flatten a list of centroid groups into a single numpy array.

    list_centroids: iterable of groups, each group being an iterable of
        centroids (e.g. one list or array of points per row of posts)

    Returns np.array of all centroids in their original order, with one
    level of nesting removed.
    """
    # A single pass avoids the quadratic cost of repeated list concatenation
    flat = [centroid for group in list_centroids for centroid in group]
    return np.array(flat)

def too_close_to_border(img, point, edge_margin=5, edge_margin_x=-1):
    """
    Tell whether a point lies within the margin of the image border.

    img: image; only the first two entries of its shape (height, width) are used
    point: (x, y) position, x being compared with the width and y with the height
    edge_margin: margin for the top and bottom borders
    edge_margin_x: margin for the left and right borders; the larger of
        edge_margin and edge_margin_x is used, so the default -1 means
        "same as edge_margin"

    Returns True when the point is closer to any border than the margin.
    """
    h, w = img.shape[:2]
    x = point[0]
    y = point[1]
    if edge_margin > edge_margin_x:
        edge_margin_x = edge_margin
    # The x margin applies to both the left and the right border
    outside_x = (x < edge_margin_x) or (x > w - edge_margin_x)
    outside_y = (y < edge_margin) or (y > h - edge_margin)
    return bool(outside_x or outside_y)

def get_var_from_dic(d_vid, settings_device_type, var_name):
    """
    Look up a parameter, letting the per-video value override the device-type default.

    d_vid: per-video settings; may contain var_name
    settings_device_type: device-type defaults; must contain var_name
    var_name: name of the parameter

    Returns the per-video value converted to float when it is present and
    converts to a number that is not NaN. Otherwise returns the device-type
    default unchanged; this includes a NaN or unconvertible per-video value,
    which previously produced None.

    Raises ValueError when the per-video value is the blank marker ' ', and
    KeyError when settings_device_type has no default for var_name.
    """
    default = settings_device_type[var_name]
    if var_name not in d_vid:
        return default

    override = d_vid[var_name]
    if override == ' ':
        raise ValueError(f'in-correct {var_name} value')

    try:
        value = float(override)
    except (TypeError, ValueError, OverflowError):
        print(f'Error (ignored) it did not work to convert {var_name} (={override}, {type(override)}) to float')
        return default

    if math.isnan(value):
        return default  # empty spreadsheet cell

    print(f'\t\tLoading from nd2-spreadsheet: {var_name}: ', value, ', ', type(value))
    return value
     

def remove_post_centers_close_to_border(img, d_vid, list_centroids, list_xycoords, settings_device_type):
    """
    Drop the post centers that lie too close to the image border.

    img: image the centroids refer to
    d_vid: per-video settings (may override 'edge_margin' / 'edge_margin_x')
    list_centroids: numpy array of (x, y) centroids, one per row
    list_xycoords: list with one entry per centroid; the entries of removed
        centroids are deleted from this list in place
    settings_device_type: device-type defaults for 'edge_margin' and
        'edge_margin_x'

    Returns (list_centroids, list_xycoords): a new centroid array without the
    removed rows and the same, shortened list_xycoords object.
    """
    edge_margin = get_var_from_dic(d_vid, settings_device_type, 'edge_margin')
    edge_margin_x = get_var_from_dic(d_vid, settings_device_type, 'edge_margin_x')

    # Indices (ascending) of all points within the margin
    inxs_to_remove = [
        inx for inx, point in enumerate(list_centroids)
        if too_close_to_border(img, point, edge_margin, edge_margin_x=edge_margin_x)
    ]

    # Delete indices from numpy array and list; the list is trimmed from the
    # back so that the remaining indices stay valid
    list_centroids = np.delete(list_centroids, inxs_to_remove, axis=0)
    for index in reversed(inxs_to_remove):
        del list_xycoords[index]

    # Report the margins that were actually applied
    print(f'\tRemoved {len(inxs_to_remove)} points that were too close to the edge (edge margin = {edge_margin}, edge margin x = {edge_margin_x})')
    return list_centroids, list_xycoords

def make_post_mask_based_on_circle_convolution(img_med_cropped, d_vid, mask_posts,settings_device_type, settings_general, disp_details = True):
    """[Creates a post mask by convolving the detected post centroids with a circle]

    Processing steps:
        1. Convert the image to uint8 and invert it.
        2. Extract the post objects (find_post_objects) and read out centroids and pixel coordinates.
        3. Remove posts whose centroid is too close to the image border.
        4. Optionally align the centroids (settings 'align_rows_of_centroids' and 'align_columns_of_centroids').
        5. Draw the centroids into an image (create_mask_from_centroid_list) and convolve it with a circle.
        6. Restrict the result to the field of view and derive the fluid mask.

    Args:
        img_med_cropped ([2D array]): [Cropped (median) image]
        d_vid ([dict-like]): [Per-video values, may override edge margins and 'conv_circle_radius']
        mask_posts ([2D array]): [Mask of the posts, passed on to find_post_objects]
        settings_device_type ([dict]): [Reads 'binning' and 'otsu_threshold_factor_sidewalls' and provides the
                                        defaults for the values looked up through get_var_from_dic]
        settings_general ([dict]): [Reads the show_imgs* flags, the two align_* flags and optionally 'show_img_zoom_window']
        disp_details (bool, optional): [Passed on to find_post_objects]. Defaults to True.

    Returns:
        mask_conv: [Boolean mask of the convolved post centers, limited to the field of view]
        mask_fluid: [Boolean mask of everything inside the field of view that is neither post nor side wall]
        list_centroids: [Centroids after border removal and (optional) alignment]
        list_xycoords: [Pixel coordinates per post after border removal (not re-ordered by the alignment)]
        bw_post_centers: [Image returned by create_mask_from_centroid_list]
    """
    print('\tMaking post mask based on circle convolution')

    #Convert to uint8 and invert the image
    img_uint8 = img_med_cropped.astype(np.uint8)
    img_inv = ~img_uint8

    #Extract the region properties out of the objects
    df_props_fin = find_post_objects(mask_posts, img_inv, settings_general, disp_details=disp_details)
    list_centroids, list_xycoords = extract_centroids_and_coords_from_df(df_props_fin)
    list_centroids, list_xycoords = remove_post_centers_close_to_border(img_med_cropped, d_vid, list_centroids, list_xycoords, settings_device_type)

    #The same three flags guard most of the diagnostic plots below
    show_part2 = settings_general['show_imgs'] and settings_general['show_imgs_preprocessing'] and settings_general['show_imgs_preprocessing_part2']

    if show_part2:
        plot_centroids(list_centroids, settings_device_type, img=img_inv, extra_text = 'After pre-processing. Inverted median image.\n', show_img = True, single_list=True)
        plot_centroids(list_centroids, settings_device_type, extra_text = 'After pre-processing.\n', show_img = True, img=mask_posts, single_list=True)

    #Align the centroid coordinates
    align_rows = settings_general['align_rows_of_centroids']
    align_cols = settings_general['align_columns_of_centroids']
    if align_rows:
        list_centroids = align_list_centroids(list_centroids, settings_device_type, settings_general, sort_axis='x')
    if align_cols:
        list_centroids = align_list_centroids(list_centroids, settings_device_type, settings_general, sort_axis='y')
    #Show the result when ANY alignment was performed (this plot used to be nested
    #inside the column branch and was therefore skipped for a row-only alignment)
    if (align_cols or align_rows) and show_part2:
        plot_centroids(list_centroids, settings_device_type, img=mask_posts, extra_text = 'After alignment of centroids.\n', show_img = True, single_list=False)

    list_centroids_1D = transform_list_centroids_to_1D(list_centroids)

    #Create a new image (cropped) from the centroid list
    bw_post_centers = create_mask_from_centroid_list(img_med_cropped, list_centroids_1D)

    if show_part2:
        #Show the sorting of the coordinates
        plot_centroids(list_centroids, settings_device_type, show_img = False, img=bw_post_centers, show_ordering_by_color=True)
        #Show post centroids
        plot_centroids(list_centroids_1D, settings_device_type, img=img_uint8, show_img = False, single_list=True)
    if settings_general['show_imgs'] and settings_general['show_imgs_preprocessing_part2']:
        if 'show_img_zoom_window' in settings_general:
            show_img_zoom_window = settings_general['show_img_zoom_window']
        else:
            show_img_zoom_window = [[150,200],[200,250]]
        plot_centroids(list_centroids_1D, settings_device_type, img=img_uint8, extra_text = 'On-top of median image\n', show_img = True, single_list=True, xlim=[150,200], ylim=[20,70])
        plot_centroids(list_centroids_1D, settings_device_type, img=img_uint8, extra_text = 'On-top of median image\n', show_img = True, single_list=True, xlim=show_img_zoom_window[0], ylim=show_img_zoom_window[1])

    if settings_device_type['binning'] == '2x2':
        #Fixed radius from the settings. It is stored in `radius` as well, because the
        #plot title further down needs it (it was undefined in this branch before).
        radius = int(get_var_from_dic(d_vid, settings_device_type, 'conv_circle_radius'))
        circle = create_circle(radius=radius)
    else:
        #Create a 2D circle object based on the median area of the detected objects
        area_med = np.median(df_props_fin.area.values)
        circle, radius = create_circle_based_on_area(area_med)

    #Convolve the post centroid matrix with the circle object
    mask_conv = signal.convolve2d(bw_post_centers, circle, mode='same')
    mask_conv = mask_conv.astype(bool)

    #Remove all posts or part of posts that are outside the field of view (only for rotated videos)
    mask_frame = (img_uint8==0).astype(bool)
    #Remove noise in case it exists:
    mask_frame = pa.threshold_based_on_area(mask_frame, area_min_factor = 100, dir='min', based_on_median=False)

    #Mask for inside the Field of View (inside the frame)
    mask_inside_FOV = ~mask_frame

    mask_sidewalls, img_list, titles_list = pa.find_side_walls(img_med_cropped, area_min_factor=200, gauss_sigma=1, otsu_threshold_factor=settings_device_type['otsu_threshold_factor_sidewalls'], invert_img = True)
    if settings_general['show_imgs'] and settings_general['show_imgs_preprocessing'] and settings_general['show_imgs_preprocessing_part1_maskgen']:
        ff.show_img(mask_sidewalls, title_text='mask of the side walls')
    if settings_general['show_imgs'] and settings_general['show_imgs_preprocessing_part1_maskgen']:
        if 'show_side_wall_pre_processing' in settings_general:
            if settings_general['show_side_wall_pre_processing']:
                print('Side wall pre-processing (turn off show_side_wall_pre_processing in settings_general if you do not want to see this)')
                for i,I in enumerate(img_list):
                    ff.show_img(I, title = str((i+1))+', '+titles_list[i])

    mask_conv = mask_conv & mask_inside_FOV
    print(mask_inside_FOV.shape, mask_sidewalls.shape)
    #Fluid = inside the field of view, but neither post nor side wall
    mask_fluid = ~mask_conv & mask_inside_FOV & ~mask_sidewalls
    if show_part2:
        ff.show_img(mask_conv, title_text = f'Post centers convolved \nwith a circle with r = {radius} pix', sz=10, cmap='gray')
        out = ff.mask_gray_img(img_med_cropped, mask_conv, color=[255, 0, 0], alpha=0.2)
        ff.show_img(out, title_text='Circles overlaid on the post array mask')

    return mask_conv, mask_fluid, list_centroids, list_xycoords, bw_post_centers


import math
import matplotlib.pyplot as plt
from skimage import draw

def create_circle(radius):
    """[Creates a square 2D array that contains a filled disk]

    The array has 2*int(radius)-1 rows and columns. The disk is drawn with
    skimage.draw.disk, centred on (radius-1, radius-1) and clipped to the array.

    Args:
        radius ([nbr]): [Radius of the disk in pixels]

    Returns:
        circle [2D float array]: [1 inside the disk, 0 elsewhere]
    """
    print(f'\t\t Creating circle with radius {radius}')
    side = 2*int(radius)-1
    centre = (radius-1, radius-1)
    kernel = np.zeros((side, side))
    rows, cols = draw.disk(centre, radius=radius, shape=kernel.shape)
    kernel[rows, cols] = 1
    return kernel

def create_circle_based_on_area(area_med):
    """[Creates a 2D circle object whose radius is derived from a circle area]

    The radius is sqrt(area/pi), rounded. An even radius is increased by one
    and a radius of 3 is increased to 5 before the circle is created.

    Args:
        area_med ([nbr]): [Assumed area of the circle]

    Returns:
        circle [2D array]: [Circle created by create_circle]
        radius [int]: [Radius that was used]
    """
    radius = round(math.sqrt(area_med/math.pi))
    #Even radii are moved to the next odd number
    radius = radius + 1 if radius % 2 == 0 else radius
    #A radius of 3 is enlarged to 5
    radius = radius + 2 if radius == 3 else radius
    return create_circle(radius), radius

def create_rot_rectangle_(w,h,angle, show_img=False):
    """[Creates a filled rectangle mask and rotates it]

    Args:
        w ([int]): [First dimension of the un-rotated array (np.ones([w, h]), i.e. its number of rows)]
        h ([int]): [Second dimension of the un-rotated array (its number of columns)]
        angle ([nbr]): [Rotation angle, passed on to PIL.Image.rotate]
        show_img (bool, optional): [Show the rotated mask with ff.show_img]. Defaults to False.

    Returns:
        [array]: [The rotated mask. Since expand=True is used, the output is enlarged
                  to hold the whole rotated rectangle; the added area is filled with black.]
    """
    #Filled boolean rectangle, converted to a PIL image for the rotation
    mask = np.ones([w,h], dtype=bool)
    pil_mask = Image.fromarray(mask)
    rotated = pil_mask.rotate(angle, resample=Image.BILINEAR, expand=True, translate=None, fillcolor='black')
    if show_img:
        ff.show_img(rotated, cmap='gray', sz=3, title=f'Rotated Rectangle Mask ({w}x{h})')
    return np.array(rotated)

def extract_centroids_and_coords_from_df(df):
    """[Extracts centroids and pixel coordinates from a region-properties data frame]

    Args:
        df ([DataFrame]): [One row per object with the columns 'coords',
                           'weighted_centroid-0' and 'weighted_centroid-1']

    Returns:
        list_centroids [n x 2 float array]: [One row per object: ('weighted_centroid-1', 'weighted_centroid-0'), i.e. x,y]
        list_xycoords [list]: [Per object the 'coords' array with its two columns swapped]
    """
    records = df.to_dict('records')

    #Swap the two coordinate columns of every object
    list_xycoords = [np.flip(rec['coords'], axis=1) for rec in records]

    list_centroids = np.zeros([len(records), 2])
    for k, rec in enumerate(records):
        #input as x,y
        list_centroids[k, 0] = rec['weighted_centroid-1']
        list_centroids[k, 1] = rec['weighted_centroid-0']
    return list_centroids, list_xycoords

def align_list_centroids(list_centroids, settings_device_type, settings_general, sort_axis='x'):    
    """[Groups the centroids into columns or rows and aligns every group on its median position]

        sort_axis = 'x': the centroids are grouped into columns (margin = settings_device_type['x_margin']).
                         For the device types 'Q' and 'H' the x-position of every centroid is replaced by the
                         median x of its column. Other device types are grouped but not modified.
        sort_axis = 'y': the centroids are grouped into rows (margin = settings_device_type['y_margin']) and
                         the y-position of every centroid is replaced by the median y of its row.

    Args:
        list_centroids ([array or list of arrays]): [Centroids (x, y); stacked with np.vstack before grouping]
        settings_device_type ([dict]): [Reads 'x_margin' or 'y_margin' and 'device_type']
        settings_general ([dict]): [Reads 'show_imgs', 'show_imgs_preprocessing' and optionally 'sort_axis'
                                    (only used for an informative print, assumed 'y' if missing)]
        sort_axis (string): [Either x (columns) or y (rows)]

    Returns:
        [list_centroids_sorted]: [List with one array of centroids per column or row]

    Raises:
        ValueError: [If sort_axis is neither x nor y]
    """
    if sort_axis == 'x': #columns
        sort_dir = 'cols'
        margin = settings_device_type['x_margin']
    elif sort_axis == 'y': #rows
        sort_dir = 'rows'
        margin = settings_device_type['y_margin']
    else:
        #Previously this case failed later on with an UnboundLocalError
        raise ValueError(f'Wrong axis input, available options are x or y, not {sort_axis}')
    print(f'\t Aligning centroids. sort direction = {sort_dir}, margin = {margin} pix')
    
    device_type = settings_device_type['device_type']
    #'sort_axis' is optional in settings_general, the other functions assume 'y' when it is missing
    sort_axis0 = settings_general.get('sort_axis', 'y')
    if sort_axis != sort_axis0:
        print(f'\t\tThe selected sort axis ({sort_axis}) is new from previously extracted ({sort_axis0}). Re-sorting the centroids with the new axis')

    #Flatten a (possibly already grouped) list into one n x 2 array
    list_centroids_1D = np.vstack(list_centroids)

    #Group the centroid coordinates row-wise or column-wise
    if sort_dir == 'rows':
        list_centroids_sorted = sort_only_centroid_coords_by_rows(list_centroids_1D, settings_device_type, margin = margin)
    else:
        list_centroids_sorted = sort_only_centroid_coords_by_columns(list_centroids_1D, settings_device_type, margin = margin)
    print(f'\t\t found {len(list_centroids_sorted)} {sort_dir}')
    if settings_general['show_imgs'] and settings_general['show_imgs_preprocessing']:
        plot_centroids(list_centroids_sorted, settings_device_type, show_img = False, show_ordering_by_color=True, sort_dir=sort_dir, axis_on=True, disp_side_text=True)

    #Replace the x-positions in every column, or the y-positions in every row, by the median of the group
    if sort_dir == 'cols':
        #'H' used to loop over the even and the odd columns separately, which gives
        #exactly the same per-column median as the single loop used for 'Q'
        if device_type in ('Q', 'H'):
            for col in list_centroids_sorted:
                col[:,0] = np.median(col[:,0])
    else:
        for row in list_centroids_sorted:
            row[:,1] = np.median(row[:,1])

    return list_centroids_sorted

def sort_centroids_sublist(list_centroids, sort_dir):
    """[Sorts the centroids of one row or one column along that row or column]

    Args:
        list_centroids ([n x 2 array-like]): [Centroids (x, y) of a single row or column]
        sort_dir (string): ['rows': sort by x (column 0). 'cols': sort by y (column 1).]

    Returns:
        [array]: [Sorted copy of the centroids]

    Raises:
        ValueError: [If sort_dir is neither rows nor cols]
    """
    if sort_dir == 'rows':
        key_col = 0 #Within a row the centroids are ordered by their x-coordinate
    elif sort_dir == 'cols':
        key_col = 1 #Within a column the centroids are ordered by their y-coordinate
    else:
        #Previously this case failed with an UnboundLocalError
        raise ValueError(f'Wrong axis input, available options are rows or cols, not {sort_dir}')

    #Convert first, so that plain lists of (x, y) pairs are accepted as well
    centroids = np.asarray(list_centroids)
    return centroids[np.argsort(centroids[:, key_col])]

def sort_centroid_coords_in_rows_or_cols_with_coords(list_centroids, list_xycoords, margin=5, sort_dir = 'rows'):    
    """[Groups the centroids into rows or columns and sorts every group, keeping the xy-coordinates in sync

        Loop:
            1. Find the remaining centroid with the lowest y (rows) or lowest x (cols)
            2. Find all remaining centroids within +-margin of that position (find_indxs_of_row_or_col)
            3. Sort this group by x (rows) or by y (cols)
            4. Add the group, and the xy-coordinates in the same order, to the output lists
            5. Mark the centroids of the group as handled
            6. Repeat until no centroids are left
    ]

    Args:
        list_centroids ([n x 2 array]): [Centroids (x, y)]
        list_xycoords ([list]): [Per post its xy-coordinates, same order as list_centroids]
        margin ([int]): [Margin to the neighboring row or column]. Defaults to 5.
        sort_dir (string): [rows or cols]

    Returns:
        [list_centroids_sorted]: [List of arrays, one sorted array of centroids per row or column]
        [list_xycoords_sorted]: [List of lists with the entries of list_xycoords, same grouping and order]

    Raises:
        ValueError: [If sort_dir is invalid, or if no centroid can be grouped
                     (e.g. the margin is not positive or a coordinate is NaN)]
    """
    if sort_dir not in ['rows', 'cols']:
        raise ValueError(f'Wrong axis input, available options are rows or cols y, not {sort_dir}')     

    centroids = np.asarray(list_centroids)
    #Float working copy in which handled centroids are overwritten with +inf. A finite marker
    #(1000 was used before) collides with real coordinates of about 1000 pix or more: the
    #marker then becomes the minimum and the loop never terminates.
    remaining = np.array(centroids, dtype=float)
    #Within a row the centroids are ordered by x (column 0), within a column by y (column 1)
    order_col = 0 if sort_dir == 'rows' else 1

    list_centroids_sorted = []
    list_xycoords_sorted = []
    while remaining.size > 0 and np.isfinite(remaining).any():
        inxs = find_indxs_of_row_or_col(remaining, margin=margin, sort_dir=sort_dir)
        if len(inxs) == 0:
            raise ValueError(f'Could not group the remaining centroids (margin = {margin}). The margin has to be positive and the centroids must not contain NaN')

        #Apply ONE permutation to the centroids and to the xy-coordinates,
        #otherwise the two output lists do not correspond to each other
        inxs = inxs[np.argsort(centroids[inxs, order_col])]
        list_centroids_sorted.append(centroids[inxs])
        list_xycoords_sorted.append([list_xycoords[inx] for inx in inxs])

        remaining[inxs,:] = np.inf
    return list_centroids_sorted, list_xycoords_sorted

def sort_centroids_list(list_centroids, settings_general, settings_device_type):
    """[Groups a list of centroids into rows or columns, based on the settings]

    Args:
        list_centroids ([array or list of arrays]): [Centroids, stacked with np.vstack before grouping]
        settings_general ([dict]): [Optional keys 'sort_axis' (default 'y') and 'sort_dir' (default 'cols')]
        settings_device_type ([dict]): [Reads 'x_margin' for sort_axis x. For sort_axis y it reads
                                        'device_type' and, for device type 'H', 'y_margin']

    Returns:
        [list_centroids_sorted]: [Result of sort_only_centroid_coords_by_rows or sort_only_centroid_coords_by_columns]

    Raises:
        ValueError: [If sort_axis is not x and not (y with device type 'H'), or if sort_dir is neither rows nor cols]
    """
    sort_axis = settings_general['sort_axis'] if 'sort_axis' in settings_general else 'y'

    #Only x, or y in combination with device type H, are supported
    if sort_axis == 'x':
        margin_key = 'x_margin'
    elif sort_axis == 'y' and settings_device_type['device_type'] == 'H':
        margin_key = 'y_margin'
    else:
        raise ValueError('This option is not eligible')
    margin = settings_device_type[margin_key]

    sort_dir = settings_general['sort_dir'] if 'sort_dir' in settings_general else 'cols'

    #Convert list of dz centers into a 1D-array
    stacked = np.vstack(list_centroids)

    #Group the centroid coordinates row-wise or column-wise
    if sort_dir not in ('rows', 'cols'):
        raise ValueError(f'Wrong axis input, available options are rows or cols y, not {sort_dir}')   
    if sort_dir == 'rows':
        return sort_only_centroid_coords_by_rows(stacked, settings_device_type, margin = margin)
    return sort_only_centroid_coords_by_columns(stacked, settings_device_type, margin = margin)

def find_row_coords(list_centroids_1D, settings_device_type, margin=5):    
    """[Finds the rows of centroids and returns the mean y-position of every row]

    Args:
        list_centroids_1D ([n x 2 array]): [Centroids (x, y)]
        settings_device_type ([dict]): [Reads 'y_margin', the margin to the neighboring row]
        margin ([int]): [Not used, it is overwritten by settings_device_type['y_margin']]

    Returns:
        [list_rows_mean_y_pos]: [Mean y-position per row, ordered by increasing lowest y of the rows]

    Raises:
        ValueError: [If no centroid can be grouped (e.g. the margin is not positive or a coordinate is NaN)]
    """
    sort_dir = 'rows'
    print('\t\tSorting centroids list by row')
    margin = settings_device_type['y_margin'] 

    centroids = np.asarray(list_centroids_1D)
    #Float working copy in which handled centroids are overwritten with +inf. A finite marker
    #(1000 was used before) collides with real coordinates of about 1000 pix or more: the
    #marker then becomes the minimum and the loop never terminates.
    remaining = np.array(centroids, dtype=float)
    list_rows_mean_y_pos = []

    while remaining.size > 0 and np.isfinite(remaining).any():
        #All remaining centroids within +-margin of the lowest remaining y
        inxs_row = find_indxs_of_row_or_col(remaining, margin=margin, sort_dir = sort_dir)
        if len(inxs_row) == 0:
            raise ValueError(f'Could not group the remaining centroids (margin = {margin}). The margin has to be positive and the centroids must not contain NaN')
        list_rows_mean_y_pos.append(np.mean(centroids[inxs_row][:,1]))
        remaining[inxs_row,:] = np.inf
    return list_rows_mean_y_pos

def find_column_coords(list_centroids_1D, settings_device_type, margin=5):    
    """[Finds the columns of centroids and returns the mean x-position of every column]

    Args:
        list_centroids_1D ([n x 2 array]): [Centroids (x, y)]
        settings_device_type ([dict]): [Reads 'x_margin', the margin to the neighboring column]
        margin ([int]): [Not used, it is overwritten by settings_device_type['x_margin']]

    Returns:
        [list_cols_mean_x_pos]: [Mean x-position per column, ordered by increasing lowest x of the columns]

    Raises:
        ValueError: [If no centroid can be grouped (e.g. the margin is not positive or a coordinate is NaN)]
    """
    sort_dir = 'cols'
    print('\t\tSorting centroids list by column')
    margin = settings_device_type['x_margin'] 

    centroids = np.asarray(list_centroids_1D)
    #Float working copy in which handled centroids are overwritten with +inf. A finite marker
    #(1000 was used before) collides with real coordinates of about 1000 pix or more: the
    #marker then becomes the minimum and the loop never terminates.
    remaining = np.array(centroids, dtype=float)
    list_cols_mean_x_pos = []

    while remaining.size > 0 and np.isfinite(remaining).any():
        #All remaining centroids within +-margin of the lowest remaining x
        inxs_col = find_indxs_of_row_or_col(remaining, margin=margin, sort_dir = sort_dir)
        if len(inxs_col) == 0:
            raise ValueError(f'Could not group the remaining centroids (margin = {margin}). The margin has to be positive and the centroids must not contain NaN')
        list_cols_mean_x_pos.append(np.mean(centroids[inxs_col][:,0]))
        remaining[inxs_col,:] = np.inf
    return list_cols_mean_x_pos

def sort_only_centroid_coords_by_rows(list_centroids_1D, settings_device_type, margin=5):    
    """[Groups the centroids into rows

        Loop:
            1. Find the remaining centroid with the lowest y
            2. Find all remaining centroids within the same y-region, y +- y_margin (find_indxs_of_row_or_col)
            3. Add these centroids as one row to the output list (in their input order, they are not sorted by x)
            4. Mark these centroids as handled
            5. Repeat until no centroids are left
    ]

    Args:
        list_centroids_1D ([n x 2 array]): [Centroids (x, y)]
        settings_device_type ([dict]): [Reads 'y_margin', the margin to the neighboring row]
        margin ([int]): [Not used, it is overwritten by settings_device_type['y_margin']]

    Returns:
        [list_centroids_sorted]: [List with one array of centroids per row, ordered by increasing lowest y]

    Raises:
        ValueError: [If no centroid can be grouped (e.g. the margin is not positive or a coordinate is NaN)]
    """
    sort_dir = 'rows'
    print('\t\tSorting centroids list by row')
    margin = settings_device_type['y_margin'] 

    centroids = np.asarray(list_centroids_1D)
    #Float working copy in which handled centroids are overwritten with +inf. A finite marker
    #(1000 was used before) collides with real coordinates of about 1000 pix or more: the
    #marker then becomes the minimum and the loop never terminates.
    remaining = np.array(centroids, dtype=float)
    list_centroids_sorted = []

    while remaining.size > 0 and np.isfinite(remaining).any():
        inxs_row = find_indxs_of_row_or_col(remaining, margin=margin, sort_dir = sort_dir)
        if len(inxs_row) == 0:
            raise ValueError(f'Could not group the remaining centroids (margin = {margin}). The margin has to be positive and the centroids must not contain NaN')
        #Fancy indexing returns a copy, so the caller may modify the rows freely
        list_centroids_sorted.append(centroids[inxs_row])
        remaining[inxs_row,:] = np.inf
    return list_centroids_sorted

def sort_only_centroid_coords_by_columns(list_centroids_1D_org, settings_device_type, margin=5):    
    """[Groups the centroids into columns

        Loop:
            1. Find the remaining centroid with the lowest x
            2. Find all remaining centroids within the same x-region, x +- margin (find_indxs_of_row_or_col)
            3. Add these centroids as one column to the output list (in their input order, they are not sorted by y)
            4. Mark these centroids as handled
            5. Repeat until no centroids are left
    ]

    Args:
        list_centroids_1D_org ([n x 2 array]): [Centroids (x, y). Not modified.]
        settings_device_type ([dict]): [Not used]
        margin ([int]): [Margin to the neighboring column]. Defaults to 5.

    Returns:
        [list_centroids_sorted]: [List with one array of centroids per column, ordered by increasing lowest x]

    Raises:
        ValueError: [If no centroid can be grouped (e.g. the margin is not positive or a coordinate is NaN)]
    """
    sort_dir = 'cols'
    centroids = np.asarray(list_centroids_1D_org)
    #Float working copy in which handled centroids are overwritten with +inf. A finite marker
    #(1000 was used before) collides with real coordinates of about 1000 pix or more: the
    #marker then becomes the minimum and the loop never terminates.
    remaining = np.array(centroids, dtype=float)
    list_centroids_sorted = []

    while remaining.size > 0 and np.isfinite(remaining).any():
        inxs_col = find_indxs_of_row_or_col(remaining, margin=margin, sort_dir=sort_dir)
        if len(inxs_col) == 0:
            raise ValueError(f'Could not group the remaining centroids (margin = {margin}). The margin has to be positive and the centroids must not contain NaN')
        #Fancy indexing returns a copy, so the caller may modify the columns freely
        list_centroids_sorted.append(centroids[inxs_col])
        remaining[inxs_col,:] = np.inf
    return list_centroids_sorted


def plot_centroids(list_centroids, settings_device_type, title='', extra_text = '', show_img = False, show_ordering_by_color=False,img=0, single_list=False, zoom_in = False, window=0, axis_on=False, xlim=[-1,-1], ylim=[-1,-1], sort_dir='rows', disp_side_text=False):
    """[Plots centroids in a new 10x10 figure and shows it]

    Args:
        list_centroids ([array or list of arrays]): [With single_list=True: one array that is indexed as [:,0] (x) and [:,1] (y).
                                                     Otherwise: a sequence of groups (rows or columns), each a sequence of (x, y) points.]
        settings_device_type ([dict]): [Only 'device_type' is read, and only for grouped input with show_ordering_by_color=True]
        title (str, optional): [Not used: the title is always re-built from extra_text and the number of entries]. Defaults to ''.
        extra_text (str, optional): [Text placed in front of the generated title]. Defaults to ''.
        show_img (bool, optional): [If True, img is drawn below the centroids. If False, the y-axis is inverted instead]. Defaults to False.
        show_ordering_by_color (bool, optional): [If True, every group gets its own color from the seaborn "prism" palette
                                                  and the marker changes along the group]. Defaults to False.
        img (array, optional): [Image to draw, only used if show_img is True]. Defaults to 0.
        single_list (bool, optional): [Selects the layout of list_centroids, see above]. Defaults to False.
        zoom_in (bool, optional): [If True (and show_img), img is cropped to window before it is drawn.
                                   The centroid coordinates are not shifted.]. Defaults to False.
        window (optional): [Crop window, img[window[0][0]:window[0][1], window[1][0]:window[1][1]]]. Defaults to 0.
        axis_on (bool, optional): [If False, the axes are hidden]. Defaults to False.
        xlim, ylim (list, optional): [Axis limits. Only applied if the last element is not -1]. Defaults to [-1,-1].
        sort_dir (str, optional): ['rows' or 'cols'. Used for the title wording and the side texts. For any other value
                                   the title words stay undefined and building a title that needs them fails]. Defaults to 'rows'.
        disp_side_text (bool, optional): [If True, text labels are added next to the last point of every group and
                                          next to the points of the last two groups (grouped input only)]. Defaults to False.

    Raises:
        ValueError: [If single_list is False and list_centroids has more than 100 entries
                     (taken as a sign that a flat list was passed by mistake)]
    """
    n = len(list_centroids)
    #A grouped list is not expected to hold more than 100 groups
    if not single_list and n > 100:
        raise ValueError('list_centroids should be 2D, not ID')
    #Wording for the title: the direction the list is grouped in, and the other direction
    if sort_dir == 'cols':
        sort_dir_text = 'columns'
        sort_dir_other = 'rows'         
    elif sort_dir == 'rows':
        sort_dir_other = 'columns'
        sort_dir_text = sort_dir
    #Build the title (this overwrites the title argument in all cases)
    if show_ordering_by_color:
        title = extra_text+f'Detected {n} {sort_dir_text} of centroids ({sort_dir_other} marker-labelled and {sort_dir_text} color-labelled)'
    else:
        if single_list:
            title = extra_text+f'Detected {n} of centroids'
        else:
            title = extra_text+f'Detected {n} {sort_dir_text} of centroids'
    if show_ordering_by_color:
        #One color per group
        n_colors = n
        pal = sns.color_palette("prism",n_colors)
    #Marker cycle that is used along a group (for all device types except 'H')
    filled_markers = ['o', 'v', '.', '<', '>', '.', 's', 'p', '.', 'h', 'H', '.', 'd', 'P', '.']
    sz=10
    fig = plt.figure(figsize=(sz, sz))  # create a figure object
    ax = fig.add_subplot(1, 1, 1)  # create an axes object in the figure
    if show_img:
        if zoom_in:
            img = img[window[0][0]:window[0][1], window[1][0]:window[1][1]]
        ax.imshow(img, cmap='viridis') 
    else: 
        #Without an image the y-axis is inverted here
        plt.gca().invert_yaxis()
    if not axis_on:
        ax.axis('off')
    if single_list:
        #Flat list: all centroids in one call, at rounded positions
        x = np.round(list_centroids[:,0])
        y = np.round(list_centroids[:,1])
        ax.plot(x, y, '.', markersize=3, color='red')
    else:
        #Grouped list: i is the index of the group, j the index of the point within the group
        for i,list_xy in enumerate(list_centroids):
            if show_ordering_by_color:
                color = pal[i]
            else:
                color = 'red'
            #k is the position in the marker cycle, restarted for every group
            k=0
            for j,xy in enumerate(list_xy):
                x = xy[0]
                y = xy[1]
                if show_ordering_by_color:
                    if settings_device_type['device_type'] == 'H':
                        #Device type H: the marker depends on the parity of the group index and of the point index
                        if i%2 == 0:
                            if j%2 == 0:
                                marker = '.'
                            else:
                                marker = '*'
                        else:
                            if j%2 == 0:
                                marker = '<'
                            else:
                                marker = 'o'
                    else:
                        marker = filled_markers[k]
                else:
                    marker = '.'

                ax.plot(x, y, marker, markersize=3, color=color)
                if disp_side_text:
                    #Text next to the last point of every group, with the number of points in the group
                    if j == len(list_xy)-1:
                        if sort_dir == 'cols':
                            rot = -90
                            plt.text(x,y+30, f'{len(list_xy)} cols', rotation=rot, fontsize='x-small')
                        elif sort_dir == 'rows':
                            rot = 0
                            plt.text(x+10,y+3, f'{len(list_xy)} rows', rotation=rot)
                    #Index text next to every point of the last group ...
                    if i == len(list_centroids)-1:
                        if sort_dir == 'cols':
                            plt.text(x+10,y+3, f'row {j}', fontsize='x-small')
                        elif sort_dir == 'rows':
                            rot = -90
                            plt.text(x+10,y+30, f'col {j}', rotation=rot)
                    #... and of the second last group
                    if i == len(list_centroids)-2:
                        if sort_dir == 'cols':
                            plt.text(x+10,y+3, f'row {j}', fontsize='x-small')
                        elif sort_dir == 'rows':
                            rot = -90
                            plt.text(x,y+30, f'col {j}', rotation=rot)
                #Next marker, start over at the end of the cycle
                k=k+1
                if k >= len(filled_markers):
                    k=0
    #NOTE(review): requires a matplotlib style named 'general' to be available (presumably a custom style sheet) - confirm
    plt.style.use('general') 
    plt.title(title)
    #The limits are only applied if they were set by the caller (last element not -1)
    if not ylim[-1] == -1:
        plt.ylim(ylim)
    if not xlim[-1] == -1:
        plt.xlim(xlim)
    plt.show()  


from skimage import measure
def find_post_centers(mask_posts, img_med, settings_general, settings_device_type, df_vid, disp_details = True):
    """[1. Segment the post mask with a watershed and find the xy-coordinates of the post centroids
           using weighted centroids in regionprops (weighted by the inverted 8-bit image)
        2. Sort the post centroids, and the corresponding pixel coordinates, into rows or columns]

    Args:
        mask_posts ([2D array]): [Mask of the posts, input to pa.watershed_segment_image]
        img_med ([2D array]): [Image that is converted to uint8 and inverted to serve as intensity image]
        settings_general ([dict]): [Reads 'show_imgs', 'show_imgs_preprocessing' and optionally 'sort_dir' (default 'cols')]
        settings_device_type ([dict]): [Reads 'y_margin'; also passed on to plot_centroids]
        df_vid ([type]): [Not used]
        disp_details (bool, optional): [Passed on to pa.watershed_segment_image]. Defaults to True.

    Returns:
        [list_centroids_sorted]: [List of arrays with the centroids per row or column]
        [list_object_coords_sorted]: [List of lists with the pixel coordinates of the posts, same grouping]
    """
    print('\tFinding post centers')
    #Invert the original image
    img_inv = ~img_med.astype(np.uint8)

    #Segment the merged posts and add them to a data frame
    #(disp_details is now forwarded, it was hard-coded to True before)
    labels_watershed = pa.watershed_segment_image(mask_posts, disp_details = disp_details, min_distance = 4)
    props_table_watershed = measure.regionprops_table(labels_watershed, intensity_image=img_inv, properties=('area', 'coords', 'major_axis_length', 'weighted_centroid'))  
    df_props_split = pd.DataFrame(props_table_watershed)

    list_centroids, list_xycoords = extract_centroids_and_coords_from_df(df_props_split)

    show_plots = settings_general['show_imgs'] and settings_general['show_imgs_preprocessing']
    if show_plots:
        sz=10
        fig = plt.figure(figsize=(sz, sz))  # create a figure object
        ax = fig.add_subplot(1, 1, 1)  # create an axes object in the figure
        ax.axis('off')
        ax.imshow(mask_posts, cmap='viridis') 
        #All centroids in one call instead of one call per point
        ax.plot(list_centroids[:,0], list_centroids[:,1], '.r', markersize=2)
        plt.style.use('general') 
        plt.title('Inverted median image with detected \npost centroid marked with red')
        plt.show()   

    #Set sort-direction, defaults to cols
    if 'sort_dir' in settings_general:
        sort_dir = settings_general['sort_dir']
    else:
        sort_dir = 'cols' 

    #Sort the centroid coordinates row-wise or column-wise
    #NOTE(review): 'y_margin' is used for columns as well, while align_list_centroids uses 'x_margin' for columns - confirm
    list_centroids_sorted, list_object_coords_sorted = sort_centroid_coords_in_rows_or_cols_with_coords(list_centroids, list_xycoords, margin = settings_device_type['y_margin'], sort_dir=sort_dir)

    #Show the sorting of the coordinates
    if show_plots:
        bw_post_centers = create_mask_from_centroid_list(img_med, list_centroids)
        #sort_dir is passed on so that the title names the actual sort direction
        plot_centroids(list_centroids_sorted, settings_device_type, show_img = False, img=bw_post_centers, show_ordering_by_color=True, sort_dir=sort_dir)
    return list_centroids_sorted, list_object_coords_sorted

def find_indxs_of_row_or_col(list_centroids_1D, margin=5, sort_dir = 'rows'):
    """Find the centroids that share a row (or column) with the lowest-coordinate centroid.

    The centroid with the smallest coordinate along the sorting axis defines the
    reference position. All centroids whose coordinate lies strictly within
    +-margin of that reference are reported as belonging to the same row/column.

    Args:
        list_centroids_1D: 2D array indexed as [centroid, coordinate]. Column 0 is
            read when sort_dir is 'cols' (x-coordinates) and column 1 when
            sort_dir is 'rows' (y-coordinates).
        margin (int, optional): Half-width of the accepted band around the lowest
            coordinate. It should include every post of one row but none of the
            next row. Defaults to 5.
        sort_dir (str, optional): 'rows' or 'cols'. Defaults to 'rows'.

    Returns:
        1D array with the indices in list_centroids_1D of the matching centroids.

    Raises:
        ValueError: If sort_dir is neither 'rows' nor 'cols'.
    """
    if sort_dir == 'rows':
        coord_column = 1
    elif sort_dir == 'cols':
        coord_column = 0
    else:
        raise ValueError(f'Wrong axis input, available options are rows or cols, not {sort_dir}')

    coords = list_centroids_1D[:, coord_column]

    # The lowest coordinate is the centre of the accepted band
    lowest = coords[np.argmin(coords)]
    lower_limit = lowest - margin
    upper_limit = lowest + margin

    # Both limits are exclusive
    inside_band = (coords > lower_limit) & (coords < upper_limit)
    return np.nonzero(inside_band)[0]


def visualize_regions(img_temp, v, xycoords_regions, list_centroids_sorted=0, plot_post_centers=True, post_center_color='black', save_img = False, with_posts = False, with_bg = False, zoom_in = False, window=0,title='', xlim=[-1,-1], ylim=[-1,-1], random_colors=False, axis_on=False, hollow_rect=False):
    """Visualize the regions defining the superpixels.

    Each region is either painted onto an RGB image or, with hollow_rect, drawn as
    a red outlined rectangle spanning the bounding box of the region.

    Args:
        img_temp: 2D image. Used as gray background when with_bg is True;
            otherwise only its shape is used to create a black canvas.
        v: Video object. Only v.dir_pixelated_files is read, and only when save_img is True.
        xycoords_regions: Nested sequence [row][region] of (N, 2) arrays holding the
            x (column 0) and y (column 1) pixel coordinates of each region.
            Entries that are not 2D or that are empty are skipped.
        list_centroids_sorted: Nested sequence [row][region] of post centers, indexed
            like xycoords_regions. Post centers are only drawn when this is supplied
            (the default 0 means "not supplied").
        plot_post_centers (bool, optional): Mark the post centers. Defaults to True.
        post_center_color (str, optional): Color of the post center markers. Defaults to 'black'.
        save_img (bool, optional): Write the RGB region image to v.dir_pixelated_files.
        with_posts (bool, optional): Only selects the file name used when saving.
        with_bg (bool, optional): Draw the regions on top of img_temp.
        zoom_in, window: Not used by this function.
        title (str, optional): Figure title; no title is set when empty.
        xlim, ylim: Axis limits. A last element of -1 means "do not set the limit".
            ylim is applied in reversed order.
        random_colors (bool, optional): Pick a random color from a 30-color
            "Spectral" palette for every region instead of one fixed red.
        axis_on (bool, optional): Show the axes. Defaults to False (axes hidden).
        hollow_rect (bool, optional): Outline the regions instead of painting them.

    Raises:
        ValueError: If img_temp is not 2D.
    """
    if len(img_temp.shape) != 2:
        raise ValueError(f'The dimension of the image should be 2D, not {len(img_temp.shape)}D')
    if random_colors:
        n_colors = 30
        pal = sns.color_palette("Spectral",n_colors)
    else:
        color = (0.6872741253364091, 0.07896962706651288, 0.2748173779315648) #red color

    h = img_temp.shape[0]
    w = img_temp.shape[1]
    if with_bg:
        bg_is_bool = img_temp.dtype == 'bool'
        if not bg_is_bool:
            # Stretch the intensities between the 0.5 and 99.5 limits returned by
            # ff.calc_percentiles onto the 8-bit range
            percentiles = ff.calc_percentiles(img_temp, min=0.5, max=99.5)
            img_temp = exposure.rescale_intensity(img_temp, in_range=(percentiles[0], percentiles[1]), out_range=(0,255)).astype(np.uint8)

        img_regions_color = np.dstack((img_temp,img_temp,img_temp)).astype(np.uint8)

        #Make the bg image gray if it was boolean, for using a boolean mask
        if bg_is_bool:
            img_regions_color = img_regions_color*100
    else:
        img_regions_color = np.zeros([h,w,3],dtype=np.uint8)

    sz=8
    fig = plt.figure(figsize=(sz, sz))  # create a figure object
    ax = fig.add_subplot(1, 1, 1)  # create an axes object in the figure
    # Fix: the axes used to be hidden when axis_on was True (inverted flag)
    if not axis_on:
        ax.axis('off')

    # The default value 0 of list_centroids_sorted can not be indexed, so only
    # draw the post centers when an actual list was supplied
    draw_post_centers = plot_post_centers and not np.isscalar(list_centroids_sorted)

    for j,row in enumerate(xycoords_regions):
        for i,region in enumerate(row):
            region = np.asarray(region)
            # Regions without pixels are stored as empty lists or single points
            if region.ndim != 2 or region.size == 0:
                continue
            x = region[:,0]
            y = region[:,1]
            if random_colors:
                i_col = random.randint(0, n_colors-1)
                color = pal[i_col]

            if hollow_rect:
                # Pixel centers sit on integer coordinates, so the outline starts
                # half a pixel before the smallest coordinate
                x_min = np.min(x)-0.5
                y_min = np.min(y)-0.5
                # Fix: use the full extent of the region; the first/last coordinate
                # only gives the extent when the coordinates happen to be sorted
                rect_w = np.max(x) - np.min(x) + 1
                rect_h = np.max(y) - np.min(y) + 1
                # Create a Rectangle patch
                rect = patches.Rectangle((x_min, y_min), rect_w, rect_h, linewidth=1, edgecolor='r', facecolor='none')

                # Add the patch to the Axes
                ax.add_patch(rect)
            else:
                # Clip so that a color component of 1.0 (-> 256) can not overflow uint8
                img_regions_color[y,x,:] = np.clip(np.array(color)*256, 0, 255)
            if draw_post_centers:
                row_posts = list_centroids_sorted[j]
                p_posts = row_posts[i]
                ax.scatter(p_posts[0],p_posts[1], marker='.', color=post_center_color, s=3)

    ax.imshow(img_regions_color)
    plt.style.use('general')
    if len(title) > 0:
        plt.title(title)
    if not ylim[-1] == -1:
        plt.ylim(np.flip(ylim))
    if not xlim[-1] == -1:
        plt.xlim(xlim)
    plt.show()

    if save_img:
        # Fix: the two file names used to be swapped with respect to with_posts
        if with_posts:
            file_name = 'regions_visualization_with_posts.tif'
        else:
            file_name = 'regions_visualization.tif'
        path_file = os.path.join(v.dir_pixelated_files, file_name)
        ti.write_tif_file(path_file, img_regions_color, photometric='minisblack')

def calc_range(x0, dist, max_val):
    """Return the [lower, upper] limits of a window of width dist centred on x0.

    Both limits are rounded with round(). The lower limit is never below 0 and the
    upper limit never above max_val, so the window is cut at the border.
    """
    half_width = dist/2
    lower = max(round(x0-half_width), 0)
    upper = min(round(x0+half_width), max_val)
    return [lower, upper]

def calc_closest_post_of_each_pixel(list_centroids_rows, list_object_coords_rows, mask_fluid, dist_max=25):
    """Assign every fluid pixel to its closest post.

    For each True pixel in mask_fluid:
        1. Pre-select the rows whose first centroid has a y-coordinate within
           +-dist_max/2 of the pixel (see calc_range).
        2. Pick the centroid with the smallest Euclidean distance among those rows.
        3. Add the pixel to the region of that post.

    Args:
        list_centroids_rows: Nested list [row][post] of post centers (x, y).
        list_object_coords_rows: Nested list [row][post] of (N, 2) arrays with the
            (x, y) pixel coordinates of each post. It is not modified.
        mask_fluid: 2D boolean mask, True for the pixels to assign.
        dist_max (int, optional): Height in pixels of the window used to pre-select
            nearby rows. Defaults to 25.

    Returns:
        (xycoords_regions, xycoords_regions_with_posts), both indexed [row][post]:
            xycoords_regions holds the (x, y) coordinates of the pixels assigned to
            each post: an empty list when no pixel was assigned, a 1D array for a
            single pixel and an (N, 2) array otherwise.
            xycoords_regions_with_posts holds the post coordinates followed by the
            assigned pixels.

    Raises:
        ValueError: If no row lies within the search window of a pixel.
    """
    #Get two lists of the fluid pixels corresponding to x and y-coordinates
    [y_coords_fluid,x_coords_fluid] = np.where(mask_fluid)

    # The y-coordinate of the first centroid of each row is the reference
    # position of that row when looking for nearby rows
    y0_list = np.array([row[0][1] for row in list_centroids_rows])

    h=mask_fluid.shape[0]

    # Pixels are collected in plain lists and stacked once at the end; stacking
    # the arrays for every single pixel copies the whole region each time
    pixels_per_post = [[[] for _ in row] for row in list_centroids_rows]

    # Loop all the pixels in the fluid
    for x0, y0 in zip(x_coords_fluid, y_coords_fluid):
        p0 = np.array([x0,y0])

        #Find all close rows
        y_range = calc_range(y0, dist=dist_max, max_val=h)
        inx_rows_near = np.where((y0_list <= y_range[1]) & (y0_list >= y_range[0]))[0]

        if len(inx_rows_near) == 0:
            raise ValueError(f'Could not find any close rows,\n for x0,y0 = ({x0},{y0}), dist_max = {dist_max} pixels.')
        #Initialize the closest position
        inx_closest = [inx_rows_near[0],0] #row, col
        dist_closest = math.dist(p0, list_centroids_rows[inx_closest[0]][inx_closest[1]])

        for inx_row_near in inx_rows_near:
            for j,p in enumerate(list_centroids_rows[inx_row_near]):
                dist = math.dist(p0,p)
                if dist < dist_closest:
                    dist_closest = dist
                    inx_closest = [inx_row_near,j]

        pixels_per_post[inx_closest[0]][inx_closest[1]].append(p0)

    # Fix: copy the rows so that the caller's list_object_coords_rows is not
    # overwritten with the merged post + fluid coordinates
    xycoords_regions = []
    xycoords_regions_with_posts = [list(list_object_coords_rows[j]) for j in range(len(list_centroids_rows))]
    for j,row_pixels in enumerate(pixels_per_post):
        regions_row = []
        for i,pixels in enumerate(row_pixels):
            if not pixels:
                regions_row.append([])
                continue
            # A single pixel is kept as a 1D array, as before
            regions_row.append(pixels[0] if len(pixels) == 1 else np.vstack(pixels))
            xycoords_regions_with_posts[j][i] = np.vstack([xycoords_regions_with_posts[j][i]] + pixels)
        xycoords_regions.append(regions_row)

    return xycoords_regions, xycoords_regions_with_posts

def calc_mean_value_for_post_region(img_stack, xycoords_regions):
    """Calculate the mean pixel value of every region in every frame.

    Args:
        img_stack: Sequence of 2D frames (e.g. a 3D array indexed [frame, y, x]).
        xycoords_regions: Nested sequence [row][region] of (N, 2) arrays with the
            x (column 0) and y (column 1) pixel coordinates of each region.

    Returns:
        Nested list indexed [frame][row][region] with the mean value of the region
        in that frame. The value is 0 for regions that can not be evaluated:
        entries that are not 2D (e.g. the empty list of a region without pixels),
        2D entries without any coordinates, and regions containing the
        coordinate -1.
    """
    # The coordinates do not change between frames, so decide once per region
    # which pixels to read (None = region without usable pixels)
    region_indices = []
    for row in xycoords_regions:
        row_indices = []
        for xycoords_region in row:
            # np.asarray also accepts the plain (empty) lists used for empty regions
            coords = np.asarray(xycoords_region)
            indices = None
            if coords.ndim == 2 and coords.size > 0:
                x = coords[:,0]
                y = coords[:,1]
                if not (np.any(y == -1) or np.any(x == -1)):
                    indices = (y, x)
            row_indices.append(indices)
        region_indices.append(row_indices)

    #For each frame
    means_regions = []
    for img in img_stack:
        means_regions_frame = []
        for row_indices in region_indices:
            means_regions_row = [0 if indices is None else np.mean(img[indices]) for indices in row_indices]
            means_regions_frame.append(means_regions_row)
        means_regions.append(means_regions_frame)
    return means_regions

def gen_mega_pixel_array(img_stack, means_regions, xycoords_regions_with_posts):
    """
    Build a "megapixel" stack in which every region is filled with its mean value.

    The result has the same shape and dtype as img_stack. For every frame, all the
    pixels listed for a region (post and neighbouring fluid) are set to
    means_regions[frame][row][region]. Pixels outside all regions stay 0.
    """
    filled = np.zeros_like(img_stack)

    for t in range(len(img_stack)):
        for r, regions_in_row in enumerate(xycoords_regions_with_posts):
            for c, coords in enumerate(regions_in_row):
                # Column 0 holds the x-coordinates and column 1 the y-coordinates
                cols, rows = coords[:,0], coords[:,1]
                filled[t, rows, cols] = means_regions[t][r][c]

    return filled


def gen_compressed_pixel_array_quad_norot(img_stack, means_regions, list_centroids, settings_device_type, settings_general):
    """Build a compressed array with one pixel per region, placed by list position.

    The value means_regions[frame][i][j] is written to column i, row j of the
    output when settings_general['sort_dir'] is 'cols', and to row i, column j
    when it is 'rows'. Any other sort direction leaves the output at zero.

    Returns:
        (img_means, img_means_compressed): two uint8 arrays of shape
        [n_frames, n_rows, n_cols]. Only img_means is filled; the second array is
        returned as all zeros.
    """
    n_cols, n_rows= calc_n_rows_and_cols(list_centroids,settings_device_type, settings_general)
    sort_dir = settings_general['sort_dir']
    n_frames = len(img_stack)
    out_shape = [n_frames, n_rows, n_cols]
    img_means = np.zeros(out_shape, dtype=np.uint8)
    img_means_compressed = np.zeros(out_shape, dtype=np.uint8)

    # The outer index of list_centroids is the x-position for column-wise sorted
    # lists and the y-position for row-wise sorted lists
    outer_is_x = sort_dir == 'cols'
    if outer_is_x or sort_dir == 'rows':
        for frame_nbr in range(n_frames):
            for i, group in enumerate(list_centroids):
                for j in range(len(group)):
                    y_new, x_new = (j, i) if outer_is_x else (i, j)
                    img_means[frame_nbr, y_new, x_new] = means_regions[frame_nbr][i][j]
    return img_means, img_means_compressed

def get_rows_list(list_dz_centers, settings_general, settings_device_type):
    """Find the row index of every center in a column-wise nested list.

    The row positions come from find_row_coords (called with margin=5). A center
    belongs to the first row whose position lies within +-margin of the
    y-coordinate of the center, where margin is settings_device_type['x_margin']
    for sort_dir 'cols' and settings_device_type['y_margin'] for sort_dir 'rows'
    on device type 'H'. sort_dir defaults to 'rows' when it is missing from
    settings_general.

    Returns:
        (list_pos, n_rows): list_pos is indexed like list_dz_centers and holds
        [column index, row index] for every center; n_rows is the number of row
        positions found.

    Raises:
        ValueError: For an unsupported sort_dir/device type combination, or when
            no row matches a center.
    """
    sort_dir = settings_general.get('sort_dir', 'rows')

    if sort_dir == 'cols':
        margin = settings_device_type['x_margin']
    elif sort_dir != 'rows' or settings_device_type['device_type'] != 'H':
        raise ValueError('This option is not eligible')
    else:
        margin = settings_device_type['y_margin']

    # Flatten the nested list of centers into one 2D array
    centers_flat = np.vstack(list_dz_centers)

    row_positions = find_row_coords(centers_flat, settings_device_type, margin=5)
    n_rows = len(row_positions)

    # Iterate all columns and all the dead zones inside each column
    list_pos = []
    for i_col, col in enumerate(list_dz_centers):
        positions_in_col = []
        for centroid in col:
            y = centroid[1]
            matches = np.where((row_positions >= y-margin) & (row_positions <= y+margin))[0]
            if matches.size == 0:
                raise ValueError('could not find it.')
            positions_in_col.append([i_col, matches[0]])
        list_pos.append(positions_in_col)
    return list_pos, n_rows

def get_cols_list(list_dz_centers, settings_general, settings_device_type):
    """Find the column index of every center in a row-wise nested list.

    The column positions come from find_column_coords (called with margin=5). A
    center belongs to the first column whose position lies within +-margin of the
    x-coordinate of the center, where margin is settings_device_type['x_margin']
    for sort_dir 'cols' and settings_device_type['y_margin'] for sort_dir 'rows'
    on device type 'H'. sort_dir defaults to 'rows' when it is missing from
    settings_general.

    Returns:
        (list_pos, n_cols): list_pos is indexed like list_dz_centers and holds
        [column index, row index] for every center; n_cols is the number of
        column positions found.

    Raises:
        ValueError: For an unsupported sort_dir/device type combination, or when
            no column matches a center.
    """
    sort_dir = settings_general.get('sort_dir', 'rows')

    if sort_dir == 'cols':
        margin = settings_device_type['x_margin']
    elif sort_dir != 'rows' or settings_device_type['device_type'] != 'H':
        raise ValueError('This option is not eligible')
    else:
        margin = settings_device_type['y_margin']

    # Flatten the nested list of centers into one 2D array
    centers_flat = np.vstack(list_dz_centers)

    col_positions = find_column_coords(centers_flat, settings_device_type, margin=5)
    n_cols = len(col_positions)

    # Iterate all rows and all the dead zones inside each row
    list_pos = []
    for i_row, row in enumerate(list_dz_centers):
        positions_in_row = []
        for centroid in row:
            x = centroid[0]
            matches = np.where((col_positions >= x-margin) & (col_positions <= x+margin))[0]
            if matches.size == 0:
                raise ValueError('could not find it.')
            positions_in_row.append([matches[0], i_row])
        list_pos.append(positions_in_row)
    return list_pos, n_cols

def gen_compressed_pixel_array_quad_rot45(img_stack, means_regions, list_centroids, settings_device_type, settings_general):
    """Build the compressed arrays for a video rotated approximately 45 degrees.

    For sort_dir 'cols' every center is placed in the column given by its position
    in list_centroids and in the row found by get_rows_list. The same values are
    also written consecutively into a 1D array per frame. For sort_dir 'rows' the
    values are placed by list position and the 1D array is left at zero.

    Returns:
        (img_means, img_means_compressed_1D): uint8 arrays of shape
        [n_frames, n_rows, n_cols] and [n_frames, n_rows*n_cols].
    """
    sort_dir = settings_general['sort_dir']
    inxs_rows_cols, n_rows = get_rows_list(list_centroids, settings_general, settings_device_type)
    n_cols = len(list_centroids)

    n_frames = len(img_stack)
    img_means = np.zeros([n_frames, n_rows, n_cols], dtype=np.uint8)
    img_means_compressed_1D = np.zeros([n_frames,n_rows*n_cols], dtype=np.uint8)

    for frame_nbr in range(n_frames):
        if sort_dir == 'cols':
            flat_inx = 0  # running position in the 1D array of this frame
            for x_new, col in enumerate(list_centroids):
                for j in range(len(col)):
                    # The row index was determined from the y-position of the center
                    y_new = inxs_rows_cols[x_new][j][1]
                    mean_val = means_regions[frame_nbr][x_new][j]
                    img_means[frame_nbr, y_new, x_new] = mean_val
                    img_means_compressed_1D[frame_nbr, flat_inx] = mean_val
                    flat_inx += 1
        elif sort_dir == 'rows':
            for y_new, row in enumerate(list_centroids):
                for x_new in range(len(row)):
                    img_means[frame_nbr, y_new, x_new] = means_regions[frame_nbr][y_new][x_new]
    return img_means, img_means_compressed_1D

def gen_compressed_pixel_array_quad(img_stack, means_regions, list_centroids, settings_device_type, settings_general, d_vid):
    """
    Generate the compressed pixel arrays from the region mean values.

    Dispatches to the rot45 variant when d_vid['angle'] is within 10 degrees of
    +-45 degrees, and to the norot variant otherwise. Returns the two arrays
    produced by the selected variant.
    """
    print(f'\t\tGenerating a compressed pixelated array')

    builder_args = (img_stack, means_regions, list_centroids, settings_device_type, settings_general)
    rotated_approx_45 = abs(45 - abs(d_vid['angle'])) < 10
    if rotated_approx_45:
        result = gen_compressed_pixel_array_quad_rot45(*builder_args)
    else:
        result = gen_compressed_pixel_array_quad_norot(*builder_args)

    img_means, img_means_compressed = result
    return img_means, img_means_compressed

def gen_pixelated_array(img_stack, settings_general, settings_device_type, v, i_img, xycoords_regions, list_centroids_rows, d_vid, gen_xycoords_regions_with_posts = True, xycoords_regions_with_posts=0):
    """Generate a pixelated array and optionally save it as a tif file.

    Note that dist_max will depend on the magnification.

    Args:
        img_stack: Image stack to pixelate.
        settings_general: Dictionary; the keys 'mode', 'save_imgs' and (when saving)
            'open_saved_file_dir_after_script' are read here.
        settings_device_type: Dictionary; 'device_type' is read when the compressed
            array is requested.
        v: Video object; frame_range, file_name0 and dir_pixelated_files are read.
        i_img: Image number in polarization mode (0: perpen, 1: parallel, 2: tot).
        xycoords_regions: Pixel coordinates of every region (excluding the posts).
        list_centroids_rows: Post centers, used for the compressed array.
        d_vid: Video dictionary passed on to gen_compressed_pixel_array_quad.
        gen_xycoords_regions_with_posts (bool, optional): True generates the
            megapixel array (regions and posts filled with the region mean),
            False the compressed array (only device type 'Q').
        xycoords_regions_with_posts: Pixel coordinates of every region including
            the posts, used for the megapixel array.

    Returns:
        The pixelated image stack.

    Raises:
        ValueError: If the compressed array is requested for a device type other
            than 'Q'.
    """
    print(f'\tGenerating a pixelated array')
    tic = time.perf_counter()

    #Calculate the mean value of each region (excluding the posts)
    means_regions = calc_mean_value_for_post_region(img_stack, xycoords_regions)

    if gen_xycoords_regions_with_posts:
        #Take the mean values of each region and fill the posts and the regions with it for each frame
        img_means = gen_mega_pixel_array(img_stack, means_regions, xycoords_regions_with_posts)
        str_side='mega_'
    elif settings_device_type['device_type'] == 'Q':
        # Fix: settings_general was missing in this call, and the function
        # returns two arrays of which only the first one is used here
        img_means, _ = gen_compressed_pixel_array_quad(img_stack, means_regions, list_centroids_rows, settings_device_type, settings_general, d_vid)
        str_side = 'compressed_'
    else:
        # Previously this case failed further down on the unassigned img_means
        raise ValueError(f"The compressed pixel array is only available for device type 'Q', not {settings_device_type['device_type']}")

    toc = time.perf_counter()
    print(f'\tGenerated a superpixel array in {toc - tic:0.2f} seconds')

    #Save pixelated array to file
    if settings_general['mode'] == 'polarization':
        if i_img == 0:
            str_side = str_side+'perpen'
        elif i_img==1:
            str_side = str_side+'parallel'
        elif i_img==2:
            str_side = str_side+'tot'
        print('\tPolarization mode: '+str_side)
        file_name_img_pixelated = 'pix_'+str_side+'_'+str(v.frame_range[0])+'-'+str(v.frame_range[-1])+'_'+v.file_name0+'.tif'
    else:
        file_name_img_pixelated = 'pix_'+str(v.frame_range[0])+'-'+str(v.frame_range[-1])+'_'+v.file_name0+'.tif'

    #Save in sub directory 'pixelated'
    dir_pixelated = os.path.join(v.dir_pixelated_files, 'pixelated')
    os.makedirs(dir_pixelated, exist_ok=True)
    path_img_pixelated = os.path.join(dir_pixelated, file_name_img_pixelated)
    if settings_general['save_imgs']:
        ti.write_tif_file(path_img_pixelated, img_means, photometric='minisblack')
        if settings_general['open_saved_file_dir_after_script']:
            os.startfile(dir_pixelated)
    return img_means

def save_pix_array(img, v, settings_general, str_side='', compression=''):
    """Save a (dead zone) pixel array as a tif file.

    The file name is 'dz_pix_[<str_side>_]<first frame>-<last frame>_<file_name0>.tif';
    the str_side part is only used in polarization mode, where it is extended with
    'tot' if settings_general['only_process_I_tot_polarization'] is set and
    otherwise with the name belonging to v.img_no (0: perpen, 1: parallel, 2: tot).

    Nothing is written unless settings_general['save_imgs'] is set. The file is
    saved in v.dir_pixelated_files, or in its sub directory '1_pix_per_dz' when
    compression is '1_pix_per_dz'.
    """
    frames_and_name = str(v.frame_range[0])+'-'+str(v.frame_range[-1])+'_'+v.file_name0+'.tif'

    if settings_general['mode'] != 'polarization':
        file_name_img_pixelated = 'dz_pix_'+frames_and_name
    else:
        if settings_general.get('only_process_I_tot_polarization', False):
            str_side = str_side+'tot'
        else:
            #[I_perpen, I_parallel, I_tot]
            names_by_img_no = {0: 'perpen', 1: 'parallel', 2: 'tot'}
            str_side = str_side+names_by_img_no.get(v.img_no, '')
        print('\tPolarization mode: '+str_side)
        file_name_img_pixelated = 'dz_pix_'+str_side+'_'+frames_and_name

    if not settings_general['save_imgs']:
        return

    # Create the target directory (and the compression sub directory) when missing
    dir_pixelated = v.dir_pixelated_files
    if not os.path.exists(dir_pixelated):
        os.mkdir(dir_pixelated)
    if compression == '1_pix_per_dz':
        dir_pixelated = os.path.join(dir_pixelated, '1_pix_per_dz')
        if not os.path.exists(dir_pixelated):
            os.mkdir(dir_pixelated)

    path_img_pixelated = os.path.join(dir_pixelated, file_name_img_pixelated)
    ti.write_tif_file(path_img_pixelated, img, photometric='minisblack')
    if settings_general['open_saved_file_dir_after_script']:
        os.startfile(dir_pixelated)


def gen_pixelated_dz_array(img_stack, settings_general, settings_device_type, v, list_dz_centers, list_dz_coords, d_vid, gen_xycoords_regions_with_posts = False, xycoords_regions_with_posts=0, split_hex=False, hex_label='', plot_post_centers=False):
    """Generate a dead zone pixel array, display it and save it.

    Note that dist_max will depend on the magnification.

    Args:
        img_stack: Image stack to pixelate.
        settings_general: Dictionary; 'show_imgs' is read here, and optionally
            'sparse_array_only_analyze_upstream_of_posts' (adds 'upstream_' to the
            file label) and 'show_img_zoom_window' ([xlim, ylim] of the zoomed view).
        settings_device_type: Dictionary; 'device_type' must be 'Q'.
        v: Video object; frame_to_display, mag and p are read for the figures.
        list_dz_centers: Nested list with the center of every dead zone.
        list_dz_coords: Nested list with the pixel coordinates of every dead zone.
        d_vid: Video dictionary; 'angle' decides whether the one-pixel-per-dead-zone
            array is saved as well.
        gen_xycoords_regions_with_posts (bool, optional): Must be False; the
            megapixel option is not available for dead zones.
        xycoords_regions_with_posts, split_hex: Not used by this function.
        hex_label (str, optional): Extra label for the file name.
        plot_post_centers (bool, optional): Mark the centers in the figures.

    Raises:
        ValueError: If gen_xycoords_regions_with_posts is True or the device type
            is not 'Q'. Both cases previously failed with an UnboundLocalError.
    """
    print(f'\tGenerating a dead zone pixel array')
    tic = time.perf_counter()

    if gen_xycoords_regions_with_posts:
        raise ValueError('option not possible: the megapixel array is not available for dead zone arrays')
    if settings_device_type['device_type'] != 'Q':
        raise ValueError(f"The dead zone pixel array is only available for device type 'Q', not {settings_device_type['device_type']}")

    #Only analyze the areas upstream of the posts? (For control array only)
    analyze_upstream_areas_only = settings_general.get('sparse_array_only_analyze_upstream_of_posts', False)
    extra_label = 'upstream_' if analyze_upstream_areas_only else ''

    #Calculate the mean value of each region
    means_regions = calc_mean_value_for_post_region(img_stack, list_dz_coords)

    img_means, img_means_verycompressed = gen_compressed_pixel_array_quad(img_stack, means_regions, list_dz_centers, settings_device_type, settings_general, d_vid)
    # Label kept as it was effectively used before, so the file names do not change
    str_side = 'compressed_hex_'

    toc = time.perf_counter()
    print(f'\tGenerated a dead zone pixel array {img_means.shape} in {toc - tic:0.2f} seconds')

    if settings_general['show_imgs']:
        if 'show_img_zoom_window' in settings_general:
            show_img_zoom_window = settings_general['show_img_zoom_window']
        else:
            show_img_zoom_window = [[150,200],[200,250]]
        # Overview followed by a zoomed-in view with outlined dead zones
        visualize_regions(img_stack[v.frame_to_display], v, list_dz_coords, list_dz_centers, plot_post_centers=plot_post_centers, post_center_color='white', save_img = False, with_bg = True, title=f'Found dead zones (colored) on top of frame #{v.frame_to_display}')
        visualize_regions(img_stack[v.frame_to_display], v, list_dz_coords, list_dz_centers, plot_post_centers=plot_post_centers, post_center_color='white', save_img = False, with_bg = True, title=f'Found dead Zones (colored) on top of frame #{v.frame_to_display} ', xlim=show_img_zoom_window[0], ylim=show_img_zoom_window[1], axis_on=True, hollow_rect=True)
        ff.show_img(img_means[v.frame_to_display], title_text = f'frame #{v.frame_to_display}, dz pixel array {img_means.shape}, {v.mag}, {v.p} mbar', cmap='gray', sz=10)
        ff.show_img(np.mean(img_means, axis=0), title_text = f'Mean intensity of the dz pixel array {img_means.shape}, {v.mag}, {v.p} mbar', cmap='gray', sz=10)
    #Save pixelated array to file
    save_pix_array(img_means, v, settings_general, str_side=extra_label+hex_label+str_side)
    # The device type is always 'Q' at this point, so of the original condition
    # (device type 'H' or 'C', or a rotation of approximately 45 degrees) only
    # the rotation check can still be true
    if abs(45 - abs(d_vid['angle'])) < 10:
        save_pix_array(img_means_verycompressed, v, settings_general, str_side=extra_label+hex_label+'1_pix_per_dz_', compression='1_pix_per_dz')

import pickle
def save_vars(v, list_centroids_rows, list_object_coords_rows, mask_fluid, xycoords_regions, xycoords_regions_with_posts, prefix=''):
    """Save the pre-processing results to a pickle file.

    The five variables are stored as one list, in the order of the arguments:
        list_centroids_rows,
        list_object_coords_rows,
        mask_fluid,
        xycoords_regions,
        xycoords_regions_with_posts

    The file is '<v.dir_pixelated_files>/pixelated/vars.pkl'. With
    prefix='dead_zone' both the directory and the file name get the prefix 'dz_',
    i.e. '<v.dir_pixelated_files>/dz_pixelated/dz_vars.pkl'.

    Args:
        v: Video object; only v.dir_pixelated_files is read.
        list_centroids_rows: Stored as is.
        list_object_coords_rows: Stored as is.
        mask_fluid: Stored as is.
        xycoords_regions: Stored as is.
        xycoords_regions_with_posts: Stored as is.
        prefix (str, optional): 'dead_zone' selects the 'dz_' names. Defaults to ''.
    """
    # Saving the objects:
    print('\t Saving variables to file.')
    is_dead_zone = prefix == 'dead_zone'

    # Fix: 'dz_' used to be put in front of the complete path instead of in front
    # of the directory name, which gave a path that does not exist
    dir_name = 'dz_pixelated' if is_dead_zone else 'pixelated'
    dir_pixelated = os.path.join(v.dir_pixelated_files, dir_name)
    os.makedirs(dir_pixelated, exist_ok=True)

    file_name = 'dz_vars.pkl' if is_dead_zone else 'vars.pkl'
    file_path_vars = os.path.join(dir_pixelated, file_name)

    with open(file_path_vars, 'wb') as f:
        pickle.dump([list_centroids_rows, list_object_coords_rows, mask_fluid, xycoords_regions, xycoords_regions_with_posts], f)

def save_vars_dz(v, list_centroids_rows, list_object_coords_rows, mask_fluid, prefix=''):
    """Save the dead zone variables to a pickle file.

    The three variables are stored as one list, in the order of the arguments:
        list_centroids_rows,
        list_object_coords_rows,
        mask_fluid

    Args:
        v: Video object; only v.dir_pixelated_files is read. The directory is
            created when it does not exist.
        list_centroids_rows: Stored as is.
        list_object_coords_rows: Stored as is.
        mask_fluid: Stored as is.
        prefix (str, optional): With 'dead_zone' the file is named 'dz_vars.pkl',
            otherwise 'vars.pkl'. Defaults to ''.
    """
    target_dir = v.dir_pixelated_files
    if not os.path.exists(target_dir):
        os.mkdir(target_dir)

    pickle_name = 'dz_vars.pkl' if prefix == 'dead_zone' else 'vars.pkl'
    file_path_vars = os.path.join(target_dir, pickle_name)
    print(f'\t Saving variables to pickle file ({file_path_vars})')

    payload = [list_centroids_rows, list_object_coords_rows, mask_fluid]
    with open(file_path_vars, 'wb') as f:
        pickle.dump(payload, f)

def load_vars(df_expd, df_vid=None):
    """Load the pre-processing results from the pickle file of the mask video.

    The pickle file holds, in this order:
        list_centroids_rows,
        list_object_coords_rows,
        mask_fluid,
        xycoords_regions,
        xycoords_regions_with_posts

    Args:
        df_expd: Data frame of the experiment list (df_exp); the columns
            'file_nbr' and 'dir_exp' are read.
        df_vid: Data frame row of the current video; the columns 'mask_file_nbr',
            'spreadsheet_file_name' and (as fallback) 'file_nbr' are read.
            The function previously read df_vid and df_exp without receiving them,
            so this argument is required although it has a default value.

    Returns:
        The five loaded variables, in the order listed above.

    Raises:
        ValueError: If df_vid is missing, if the file number is not found in the
            experiment list or if the pickle file does not exist.
    """
    if df_vid is None:
        raise ValueError('load_vars needs df_vid (the data frame of the current video) to find the mask file number.')
    df_exp = df_expd

    file_nbr_load_from = df_vid['mask_file_nbr'].values[0] #Find the file number to where to load the variables from.
    #If it could not find a file number to load from, it is probably itself, so take the file number from itself instead and load variables from previously run session
    if np.isnan(float(file_nbr_load_from)):
        if 'file_nbr' not in df_vid.columns:
            raise ValueError('mask_file_nbr is empty and df_vid has no file_nbr column to fall back on.')
        # Fix: the fallback used to read mask_file_nbr again, i.e. the same nan value
        file_nbr_load_from = df_vid['file_nbr'].values[0]
    print(f'Loading mask data from {file_nbr_load_from} in the exp_list.')
    df_vid_load_from = df_exp[df_exp['file_nbr'] == file_nbr_load_from] #Find the right data frame settings.
    if len(df_vid_load_from) == 0:
        raise ValueError(f'Could not find the file with file number {file_nbr_load_from} in the exp_list.')

    # Find the right video settings (only need the directory)
    v_file_nbr_load_from = ti.vid(dir_exp=df_vid_load_from['dir_exp'].values[0],
           spreadsheet_file_name=df_vid['spreadsheet_file_name'].values[0],
           file_nbr=file_nbr_load_from)

    # Set the path to where to load the file from
    path_file_nbr_load_from = os.path.join(v_file_nbr_load_from.dir_save, 'pixelated', 'vars.pkl')
    if not os.path.isfile(path_file_nbr_load_from):
        raise ValueError(f'Could not find saved files with masking settings variables,\npath = {path_file_nbr_load_from}')

    print(f'\tLoaded variables from file {file_nbr_load_from} with mag={v_file_nbr_load_from.mag} and p={v_file_nbr_load_from.p} mbar')
    # Getting back the objects:
    # NOTE: pickle can execute arbitrary code while loading; only load files
    # written by this pipeline.
    with open(path_file_nbr_load_from, 'rb') as f:
        list_centroids_rows, list_object_coords_rows, mask_fluid, xycoords_regions, xycoords_regions_with_posts = pickle.load(f)
    return list_centroids_rows, list_object_coords_rows, mask_fluid, xycoords_regions, xycoords_regions_with_posts

def load_vars_dz(df_exp, df_vid, prefix=''):
    """Load the dead zone variables from the pickle file of the mask video.

    The pickle file ('dz_vars.pkl' in the dir_pixelated_files of the mask video)
    holds, in this order:
        list_centroids_rows,
        list_object_coords_rows,
        mask_fluid

    Args:
        df_exp: Data frame of the experiment list; the columns 'file_nbr' and
            'file_path' are read.
        df_vid: Data frame row of the current video; the columns 'mask_file_nbr',
            'dir_exp' and (as fallback) 'file_nbr' are read.
        prefix (str, optional): Not used by this function.

    Returns:
        The three loaded variables, in the order listed above.

    Raises:
        ValueError: If the file number is not found in the experiment list or if
            the pickle file does not exist.
    """

    file_nbr_load_from = df_vid['mask_file_nbr'].values[0] #Find the file number to where to load the variables from.
    #If it could not find a file number to load from, it is probably itself, so take the file number from itself instead and load variables from previously run session
    if np.isnan(float(file_nbr_load_from)):
        if 'file_nbr' not in df_vid.columns:
            raise ValueError('mask_file_nbr is empty and df_vid has no file_nbr column to fall back on.')
        # Fix: the fallback used to read mask_file_nbr again, i.e. the same nan value
        file_nbr_load_from = df_vid['file_nbr'].values[0]
    print(f'====================================================================================\n==================Loading mask data from file #{file_nbr_load_from} in the exp_list.==================\n====================================================================================')
    df_vid_load_from = df_exp[df_exp['file_nbr'] == file_nbr_load_from] #Find the right data frame settings.
    if len(df_vid_load_from) == 0:
        raise ValueError(f'Could not find the file with file number {file_nbr_load_from} in the exp_list.')

    # Only the metadata of the mask video is needed, so no images are read
    file_path = df_vid_load_from.file_path.values[0]
    v = nd2_handling2.Video_waves(file_path, read_img_directly=False, frame_range=[0,0], dir_exp=df_vid.dir_exp.values[0])

    # Set the path to where to load the file from
    path_file_nbr_load_from = os.path.join(v.dir_pixelated_files, 'dz_vars.pkl')
    if not os.path.isfile(path_file_nbr_load_from):
        raise ValueError(f'Could not find saved files with masking settings variables,\npath = {path_file_nbr_load_from}')

    print(f'\tLoaded variables from pickled file {file_nbr_load_from} with mag={v.mag} and p={v.p} mbar')
    # Getting back the objects:
    # NOTE: pickle can execute arbitrary code while loading; only load files
    # written by this pipeline.
    with open(path_file_nbr_load_from, 'rb') as f:
        list_centroids_rows, list_object_coords_rows, mask_fluid = pickle.load(f)
    return list_centroids_rows, list_object_coords_rows, mask_fluid

def pre_processing_post_vicinity(v, df_vid, d_vid, settings_general, settings_device_type):
    """Run the pre-processing chain of the 'post_vicinity' mode and save its results.

    The steps are, in order: retrieve the median image, generate the masks,
    locate the post centers, assign fluid pixels to posts and store
    everything through ``save_vars``.

    Args:
        v ([video object]): Passed on to the helper functions.
        df_vid ([pandas data frame]): Settings of the video; an optional
            'dist_max' column overrides the device-type default unless it is nan.
        d_vid ([dictionary]): Passed on to the helper functions.
        settings_general ([dictionary]): Passed on to the helper functions.
        settings_device_type ([dictionary]): Must contain 'dist_max'.

    Returns:
        [tuple]: (list_centroids_sorted, list_object_coords_sorted, mask_fluid,
        xycoords_regions, xycoords_regions_with_posts)
    """
    print('\n=== Commencing pre-processing ===')

    # Median image(s) of the stack (loaded or created by the helper)
    median_img = retrieve_med_file(settings_general, v, df_vid, d_vid)

    # Masks of the non-fluid pixels (posts and side walls); the fourth output is not used here
    mask_fluid, mask_posts, median_img, _mask_conv_3 = gen_masks(
        median_img, v, df_vid, d_vid, settings_general, settings_device_type)

    # xy-coordinates of the post centroids
    centroids, object_coords = find_post_centers(
        mask_posts, median_img, settings_general, settings_device_type, df_vid, disp_details = True)

    # Maximum distance: device-type default, overridden per video when a value is given
    dist_max = settings_device_type['dist_max']
    if 'dist_max' in df_vid.columns:
        dist_max_video = df_vid['dist_max'].values[0]
        if not np.isnan(dist_max_video):
            dist_max = df_vid['dist_max'].values[0]

    # Closest post for each pixel that is within the "fluid"
    regions, regions_with_posts = calc_closest_post_of_each_pixel(
        centroids, object_coords, mask_fluid, dist_max=dist_max)

    # Save variables to a pickle-file
    save_vars(v, centroids, object_coords, mask_fluid, regions, regions_with_posts)

    return centroids, object_coords, mask_fluid, regions, regions_with_posts

def pre_processing_dead_zone(v, df_vid, d_vid, settings_general, settings_device_type):
    """Run the pre-processing chain of the 'dead_zone' mode and save its results.

    The steps are, in order: retrieve the median image, generate the masks,
    locate the post centers and store the results through ``save_vars_dz``.

    Args:
        v ([video object]): Passed on to the helper functions.
        df_vid ([pandas data frame]): Passed on to the helper functions.
        d_vid ([dictionary]): Passed on to the helper functions.
        settings_general ([dictionary]): Passed on to the helper functions.
        settings_device_type ([dictionary]): Passed on to the helper functions.

    Returns:
        [tuple]: (list_centroids_sorted, list_object_coords_sorted, mask_fluid)
    """
    print('Commencing pre-processing of Dead Zone Analysis')

    # Median image(s) of the stack (loaded or created by the helper)
    median_img = retrieve_med_file(settings_general, v, df_vid, d_vid)

    # Masks of the non-fluid pixels (posts and side walls); the fourth output is not used here
    mask_fluid, mask_posts, median_img, _mask_conv_3 = gen_masks(
        median_img, v, df_vid, d_vid, settings_general, settings_device_type)

    # xy-coordinates of the post centroids
    centroids, object_coords = find_post_centers(
        mask_posts, median_img, settings_general, settings_device_type, df_vid, disp_details = True)

    # Save variables to a pickle-file
    save_vars_dz(v, centroids, object_coords, mask_fluid, prefix='dead_zone')

    return centroids, object_coords, mask_fluid


def show_masking_results(img, mask_fluid, v, settings_general, df_vid, xycoords_regions, list_centroids_sorted):
    """Display the found regions and the fluid mask on top of one frame.

    Args:
        img: Image stack; ``img[v.frame_to_display]`` is the frame that is shown.
        mask_fluid: Fluid mask; every non-zero pixel counts as set.
        v ([video object]): Provides ``frame_to_display`` and is passed on to
            ``visualize_regions``.
        settings_general ([dictionary]): Reads 'show_imgs', 'show_masking_results'
            and the optional 'show_img_zoom_window' ([xlim, ylim]).
        df_vid ([pandas data frame]): The region images are only drawn when its
            first 'mask_file_nbr' value is nan.
        xycoords_regions: Passed on to ``visualize_regions``.
        list_centroids_sorted: Passed on to ``visualize_regions``.
    """
    show_imgs = settings_general['show_imgs']

    # mask_file_nbr is only parsed when images are requested, so an unparsable
    # value cannot abort a run in which nothing would be displayed anyway.
    if show_imgs and np.isnan(float(df_vid.mask_file_nbr.values[0])):
        if 'show_img_zoom_window' in settings_general:
            show_img_zoom_window = settings_general['show_img_zoom_window']
        else:
            show_img_zoom_window = [[150,200],[200,250]]
        frame = img[v.frame_to_display]
        # Overview of the regions
        visualize_regions(frame, v, xycoords_regions, list_centroids_sorted, plot_post_centers=False, post_center_color='white', save_img = True)
        # Zoomed view: element 0 of the window is the x-range, element 1 the y-range
        visualize_regions(frame, v, xycoords_regions, list_centroids_sorted, plot_post_centers=False, save_img = True, with_posts = True, xlim=show_img_zoom_window[0], ylim=show_img_zoom_window[1], hollow_rect=True)

    if show_imgs or settings_general['show_masking_results']:
        mask_fluid_bool = mask_fluid != 0
        img_to_display = img[v.frame_to_display]

        # First the mask itself, then its inverse, both overlaid in red
        out = ff.mask_gray_img(img_to_display, mask_fluid_bool, color=[255, 0, 0], alpha=0.5)
        ff.show_img(out, sz=20, title_text=f'Fluid mask overlaid on the frame {v.frame_to_display}')

        out = ff.mask_gray_img(img_to_display, np.invert(mask_fluid_bool), color=[255, 0, 0], alpha=0.5)
        ff.show_img(out, sz=20, title_text=f'Fluid mask overlaid on the frame {v.frame_to_display}')

def determine_if_pre_process_or_not(settings_general, df_vid):
    """Decide whether pre-processing must run or saved masking settings are loaded.

    Args:
        settings_general ([dictionary]): 'load_masking_settings_from_file'
            forces loading (i.e. no pre-processing) when truthy.
        df_vid ([pandas data frame]): The first 'mask_file_nbr' value is inspected.

    Returns:
        [bool]: True if pre-processing should be performed. It is False when
        'mask_file_nbr' is a string holding a number larger than 0, when it is
        a non-string value that is not nan, or when loading is forced.
    """
    mask_file_nbr = df_vid.mask_file_nbr.values[0]
    needs_pre_processing = True

    if isinstance(mask_file_nbr, str):
        try:
            if float(mask_file_nbr) > 0:
                needs_pre_processing = False
        except ValueError:
            # Blank or otherwise non-numeric text: nothing to load from
            print('Did not find a mask file number to load the masking settings from.')
    elif not np.isnan(float(mask_file_nbr)):
        # A number was given: load masking settings from that file
        needs_pre_processing = False

    print('mask_file_nbr ',mask_file_nbr, type(mask_file_nbr))
    if settings_general['load_masking_settings_from_file']:
        needs_pre_processing = False
    return needs_pre_processing
def pre_process(settings_general, df_exp, df_vid, d_vid, v, settings_device_type):
    """Pre-process the video, or load saved results, for the selected pixelation mode.

    Args:
        settings_general ([dictionary]): Reads 'mode_pixelation' ('post_vicinity'
            or 'dead_zone'), 'load_masking_settings_from_file' and the optional
            'perform_only_preprocessing'.
        df_exp ([pandas data frame]): Experiment list, used when loading saved variables.
        df_vid ([pandas data frame]): Settings of the video ('mask_file_nbr' is read).
        d_vid ([dictionary]): Passed on to the pre-processing functions.
        v ([video object]): Passed on to the pre-processing functions.
        settings_device_type ([dictionary]): Passed on to the pre-processing functions.

    Returns:
        [tuple]: (list_centroids_sorted, list_object_coords_sorted, mask_fluid)

    Raises:
        ValueError: If 'mode_pixelation' holds an unknown value.
        SystemExit: After a fresh pre-processing run when
            'perform_only_preprocessing' is set.
    """
    exit_message = "Quitting program after pre-processing, see the perform_only_preprocessing parameter in settings_general"
    mode = settings_general['mode_pixelation']

    if mode == 'post_vicinity':
        load_from_file = (not np.isnan(float(df_vid.mask_file_nbr.values[0]))) or settings_general['load_masking_settings_from_file']
        if load_from_file:
            # The two region outputs are not part of this function's return value
            (list_centroids_sorted, list_object_coords_sorted, mask_fluid,
             _xycoords_regions, _xycoords_regions_with_posts) = load_vars(df_exp, df_vid)
        else:
            (list_centroids_sorted, list_object_coords_sorted, mask_fluid,
             _xycoords_regions, _xycoords_regions_with_posts) = pre_processing_post_vicinity(v, df_vid, d_vid, settings_general, settings_device_type)
            if settings_general.get('perform_only_preprocessing'):
                sys.exit(exit_message)

    elif mode == 'dead_zone':
        if determine_if_pre_process_or_not(settings_general, df_vid):
            list_centroids_sorted, list_object_coords_sorted, mask_fluid = pre_processing_dead_zone(v, df_vid, d_vid, settings_general, settings_device_type)
            if settings_general.get('perform_only_preprocessing'):
                sys.exit(exit_message)
        else:
            print(f'\tmask_file_nbr: {df_vid.mask_file_nbr.values[0]}')
            list_centroids_sorted, list_object_coords_sorted, mask_fluid = load_vars_dz(df_exp, df_vid)
    else:
        raise ValueError("Wrong value of 'mode_pixelation'")

    return list_centroids_sorted, list_object_coords_sorted, mask_fluid

def generate_pixelated_array(img, mask_fluid, settings_general, df_vid, d_vid, v, settings_device_type, list_centroids_sorted):
    """Generate the pixelated array for the selected pixelation mode.

    Only the combination 'dead_zone' mode and device type 'Q' performs any
    work. The pixelation call for the 'post_vicinity' mode is currently
    disabled in this function, so that mode returns without doing anything.

    Args:
        img: Image stack to pixelate.
        mask_fluid: Fluid mask.
        settings_general ([dictionary]): Reads 'mode_pixelation'.
        df_vid ([pandas data frame]): Not used here.
        d_vid ([dictionary]): Reads 'device_type' in the 'dead_zone' mode.
        v ([video object]): Passed on to the helper functions.
        settings_device_type ([dictionary]): Passed on to ``gen_pixelated_dz_array``.
        list_centroids_sorted: Post centroids, used to locate the dead zones.
    """
    if settings_general['mode_pixelation'] != 'dead_zone':
        return
    if d_vid['device_type'] != 'Q':
        return

    # Locate the dead zones, then pixelate the stack with them
    dz_centers, dz_coords = find_dead_zone_centers_and_coords_quad(
        img, mask_fluid, v, settings_general, list_centroids_sorted, w_dz = 3, h_dz = 2)
    gen_pixelated_dz_array(img, settings_general, settings_device_type, v, dz_centers, dz_coords, d_vid)

def double_check_dimensions(img, mask_fluid):
    """Make sure that the image and the fluid mask have the same shape.

    On a mismatch both arrays are displayed before the error is raised.

    Raises:
        ValueError: If ``img.shape`` differs from ``mask_fluid.shape``.
    """
    if img.shape == mask_fluid.shape:
        return

    # NOTE(review): other calls in this file pass title_text= to ff.show_img - confirm that title= is accepted
    ff.show_img(img, title='Image')
    ff.show_img(mask_fluid, title='Mask Fluid')
    raise ValueError(f'The dimensions of the image {img.shape} is not the same as the fluid mask {mask_fluid.shape}.\n One of them probably needs rotation:')
    
def array_pixelation(file_path, settings_general, d):
    """Create a pixelated array and/or an image stack with the posts masked out.

    Args:
        file_path: Not used by this function; the video is read through
            ``read_file_array_pixelation`` based on ``df_vid``/``d_vid``.
        settings_general ([dictionary]): General settings. 'show_imgs_preprocessing'
            is set to False in place when it is missing. 'mode' and
            'create_img_posts_masked_out' are read, as well as the optional
            'only_perform_create_img_posts_masked_out'.
        d ([dictionary]): Must contain 'df_vid', 'd_vid', 'df_exp' and
            'dict_list_of_exp'. An 'n_frames_limit' entry is currently ignored.

    Returns:
        [dictionary]: Always an empty dictionary.

    Raises:
        ValueError: If one of the required keys is missing from ``d``.
    """
    #Load parameters from dictionary d, reporting the key that is actually missing
    for key in ('df_vid', 'd_vid', 'df_exp', 'dict_list_of_exp'):
        if key not in d:
            raise ValueError(f'Could not find {key} in the dictionary')
    df_vid = d['df_vid']
    d_vid = d['d_vid']
    df_exp = d['df_exp']
    dict_list_of_exp = d['dict_list_of_exp']

    #Display images of the pre-processing or not? If this parameter is lacking from the dictionary, display is turned off
    settings_general.setdefault('show_imgs_preprocessing', False)

    #Retrieve the parameters for the device type into a dictionary
    settings_device_type = wg.generate_device_type_settings_dict(df_vid)

    #Read file (files if the image was recorded with optosplit)
    imgs, v = read_file_array_pixelation(df_vid, d_vid, settings_general)

    #Update the data frame subset with metadata and information from the file name
    df_vid = gen_video_df(v, df_vid)

    #Pre-processing based on the selected pixelation mode
    list_centroids_sorted, list_object_coords_sorted, mask_fluid = pre_process(settings_general, df_exp, df_vid, d_vid, v, settings_device_type)

    #imgs is a list of images, e.g. left and right side or just one item in case of normal, non-optosplit mode
    for img_no, img in enumerate(imgs):
        print(f'\nAnalyzing the image stack {img_no+1}/{len(imgs)} ('+settings_general['mode']+'):')

        # Double-check that the dimensions of the image and fluid mask are the same
        double_check_dimensions(img[0], mask_fluid)

        # `== True` is kept on purpose: only a value equal to True skips the pixelation
        only_perform_create_img_posts_masked_out = (
            settings_general.get('only_perform_create_img_posts_masked_out', False) == True)
        if only_perform_create_img_posts_masked_out:
            print('!!! Only creating the image with the posts masked out and no pixelation!')
        else:
            v.img_no = img_no
            generate_pixelated_array(img, mask_fluid, settings_general, df_vid, d_vid, v, settings_device_type, list_centroids_sorted)

        if settings_general['create_img_posts_masked_out']:
            #Save the image where the posts have been masked out
            create_img_posts_masked_out(img, mask_fluid, v, settings_general, dict_list_of_exp)

    #Return an empty dictionary
    dict_var_fun_specific = {}
    return dict_var_fun_specific

def create_img_posts_masked_out(img, mask_fluid, v, settings_general, dict_list_of_exp, save_also_raw_img=True):
    """Zero all non-fluid pixels of an image stack and optionally show/save the result.

    Note that ``img`` is modified in place.

    Args:
        img: Image stack; the mask is applied to the axes after the first one
            (``img[:, mask]``).
        mask_fluid: Fluid mask; pixels equal to 0/False are treated as non-fluid.
        v ([video object]): Provides ``frame_range``, ``file_name0`` and
            ``dir_img_no_posts`` for the output file names.
        settings_general ([dictionary]): Reads 'show_imgs' and 'save_imgs', and
            'open_saved_file_dir_after_script' when 'save_imgs' is set.
        dict_list_of_exp ([dictionary]): 'conc' holds the concentration.
        save_also_raw_img (bool, optional): Also write the unmasked stack to a
            'raw_...' tif-file before masking. Defaults to True.

    Raises:
        ValueError: If the concentration is 0 or ' '.
    """
    def tif_path(prefix):
        # Output path inside the directory for images without posts
        file_name = prefix+str(v.frame_range[0])+'-'+str(v.frame_range[-1])+'_'+v.file_name0+'.tif'
        return os.path.join(v.dir_img_no_posts, file_name)

    if save_also_raw_img:
        ti.write_tif_file(tif_path('raw_'), img, photometric='minisblack')

    #Set all pixels that are not considered "fluid" as 0. They will not be added when finding the average intensity.
    #Comparing with 0 (rather than using ~) also works for masks that are not of boolean type.
    img[:, np.asarray(mask_fluid) == 0] = 0

    c = dict_list_of_exp['conc']
    print(f'Conc. = {c} ng/uL')
    if c == 0 or c == ' ':
        raise ValueError(f'Incorrect value of the concentration (={c})')

    if settings_general['show_imgs']:
        pa.show_local_conc_img(img, v, settings_general,c=c)
        pa.show_histogram_local_conc(img, v,settings_general, c = c)

    if settings_general['save_imgs']:
        ti.write_tif_file(tif_path('posts_masked_out_'), img, photometric='minisblack')
        if settings_general['open_saved_file_dir_after_script']:
            # os.startfile only exists on Windows
            os.startfile(v.dir_img_no_posts)

# Root directory of the analysis project; get_exp_list_df() builds the path to the experiment list from it.
# NOTE(review): hard-coded, user-specific Windows path - presumably has to be adapted on other machines.
dir_main = r'C:\Users\os4875st\.py\py_DNA_waves\waves analysis'
def get_exp_list_df():
    """Read the list of experiments (an Excel file below ``dir_main``) into a data frame.

    Returns:
        [pandas data frame]: Content of 'list_of_experiments.xlsx'.
    """
    # NOTE(review): `list` is presumably not a dtype pandas can convert to - confirm that the 'pressures' entry works as intended
    column_dtypes = {'exp_ID':str, 'pressures':list, 'device_type':str, 'experiment_type':str, 'pressures_to_ignore':float}
    path_exp_list = os.path.join(dir_main, 'array_pixilation', 'list_of_experiments.xlsx')
    return pd.read_excel(path_exp_list, dtype=column_dtypes)

def make_coords_array(mask_fluid, x_range, y_range):
    """Find the allowed coordinates of a dead zone inside a rectangle.

    Every (x, y) combination of the two ranges is kept if it lies inside the
    image and the fluid mask is set at that pixel.

    Args:
        mask_fluid: 2D fluid mask, indexed as ``mask_fluid[y, x]``.
        x_range: Integer x-coordinates (columns) of the rectangle.
        y_range: Integer y-coordinates (rows) of the rectangle.

    Returns:
        [numpy array]: Integer array with one [x, y] row per allowed pixel,
        ordered by x and then y. It is empty (size 0) if no pixel is allowed.
    """
    h_img, w_img = mask_fluid.shape
    #Coordinates should not be out of bounds or inside a post. The lower bound matters:
    #a negative index would otherwise silently wrap around to the opposite image edge.
    xy = [[x, y]
          for x in x_range
          for y in y_range
          if (0 <= x < w_img) and (0 <= y < h_img) and mask_fluid[y, x]]
    return np.array(xy).astype(int)
def find_coords_dz(mask_fluid, x_dead_zone_center, y_dead_zone_center, h_dz = 2, w_dz = 3, h_fixed=False, edge_margin=4):
    """Find the allowed coordinates of the dead zone based on the fluid mask and pre-defined dead zone width and height.

    Args:
        mask_fluid: 2D fluid mask; only pixels where it is set are returned.
        x_dead_zone_center: x-coordinate of the dead zone center (may be fractional).
        y_dead_zone_center: y-coordinate of the dead zone center (may be fractional).
        h_dz (int, optional): Height of the dead zone. Defaults to 2.
        w_dz (int, optional): Width of the dead zone; must be a positive odd
            number. Defaults to 3.
        h_fixed (bool, optional): If True and fewer than ``h_dz`` distinct rows
            were found, the rectangle is extended by one row on each side and
            searched again. Defaults to False.
        edge_margin (int, optional): Dead zones whose center is closer than this
            to an image edge are skipped. Defaults to 4.

    Raises:
        ValueError: If ``w_dz`` is even or smaller than 1.
        ValueError: If ``h_fixed`` is set and the extended dead zone still has
            fewer than ``h_dz - 1`` distinct rows.

    Returns:
        [numpy array]: One [x, y] row per dead zone pixel; an empty array if the
        center is too close to an edge or no pixel is allowed.
    """
    #If a dead zone is too close to an edge, return an empty array
    h_img, w_img = mask_fluid.shape
    too_close_to_edge = ((x_dead_zone_center < edge_margin)
                         or (x_dead_zone_center > w_img - edge_margin)
                         or (y_dead_zone_center < edge_margin)
                         or (y_dead_zone_center > h_img - edge_margin))
    if too_close_to_edge:
        return np.array([])

    #x-range: a window of w_dz pixels that is symmetric around the rounded center
    if w_dz%2 != 1:
        raise ValueError('Did not yet write the script if w_dz is even')
    if w_dz < 1:
        raise ValueError(f'w_dz must be a positive odd number, got {w_dz}')
    x_dead_zone_center_rounded = int(round(x_dead_zone_center))
    half_width = int((w_dz - 1)//2)
    x_range = np.arange(x_dead_zone_center_rounded - half_width, x_dead_zone_center_rounded + half_width + 1)

    #y-range
    if h_dz == 2:
        #If the height of the dead zones is only 2 pixels, find the two pixels that are closest to the dead zone center:
        y_dead_zone_center_rounded = int(round(y_dead_zone_center))
        if y_dead_zone_center_rounded-y_dead_zone_center > 0:
            #The center was rounded up, so the second pixel is the one above
            y_range = np.array([y_dead_zone_center_rounded-1, y_dead_zone_center_rounded])
        else:
            y_range = np.array([y_dead_zone_center_rounded, y_dead_zone_center_rounded+1])
    else:
        # NOTE(review): this range can span more than h_dz rows (both ends are included) - confirm that this is intended
        y_min = int(round(y_dead_zone_center - h_dz/2))
        y_max = int(round(y_dead_zone_center + h_dz/2))
        y_range = np.arange(y_min, y_max+1)

    xy = make_coords_array(mask_fluid, x_range, y_range)

    if xy.size == 0:
        print(x_range, y_range)
        print('1D')
        print('xy is empty')
        return xy

    #If the height should be fixed, and it is smaller than specified, make it taller. Only one side of it will be added because the other will go into the fluid mask
    n_y_values = np.unique(xy[:,1]).shape[0]
    if h_fixed and (n_y_values < h_dz):
        y_range2 = np.arange(y_range[0]-1, y_range[-1]+2)
        xy = make_coords_array(mask_fluid, x_range, y_range2)
        n_y_values = np.unique(xy[:,1]).shape[0]
        if n_y_values < h_dz-1:
            raise ValueError(f'Dead zone Too narrow in height ({n_y_values} < {h_dz})\nfor dead zone with coordinates({x_dead_zone_center:.1f},{y_dead_zone_center:.1f}')
    return xy

def add_dz_lists_for_row(mask_fluid, row, row_nbr, d_y, h_dz, w_dz, h_fixed=False):
    """Create one dead zone per post of a row, located ``d_y`` pixels above the post center.

    Args:
        mask_fluid: Fluid mask, passed on to ``find_coords_dz``.
        row: Sequence of (x, y) post center coordinates.
        row_nbr: Not used.
        d_y: Distance that is subtracted from the y-coordinate of each post
            center (y increases from top to bottom).
        h_dz: Height of a dead zone.
        w_dz: Width of a dead zone.
        h_fixed (bool, optional): Passed on to ``find_coords_dz``. Defaults to False.

    Returns:
        [tuple]: (list with the [x, y] center of every dead zone,
        list with the coordinate array of every dead zone)
    """
    dz_centers = []
    dz_coords = []
    for x_post, y_post in row:
        x_dz = x_post
        y_dz = y_post-d_y
        dz_centers.append(np.array([x_dz, y_dz]))
        dz_coords.append(find_coords_dz(mask_fluid, x_dz, y_dz, h_dz = h_dz, w_dz = w_dz, h_fixed=h_fixed))
    return dz_centers, dz_coords


def add_dz_lists_for_col(mask_fluid, row, row_nbr, d_y, h_dz, w_dz, h_fixed=False):
    """Create one dead zone per post, located ``d_y`` pixels above the post center.

    NOTE(review): the body performs exactly the same steps as
    ``add_dz_lists_for_row`` - confirm whether a column-specific offset was intended.

    Args:
        mask_fluid: Fluid mask, passed on to ``find_coords_dz``.
        row: Sequence of (x, y) post center coordinates.
        row_nbr: Not used.
        d_y: Distance that is subtracted from the y-coordinate of each post
            center (y increases from top to bottom).
        h_dz: Height of a dead zone.
        w_dz: Width of a dead zone.
        h_fixed (bool, optional): Passed on to ``find_coords_dz``. Defaults to False.

    Returns:
        [tuple]: (list with the [x, y] center of every dead zone,
        list with the coordinate array of every dead zone)
    """
    centers, coords = [], []
    for post_center in row:
        x_center, y_post_center = post_center
        y_center = y_post_center-d_y
        centers.append(np.array([x_center, y_center]))
        xy = find_coords_dz(mask_fluid, x_center, y_center, h_dz = h_dz, w_dz = w_dz, h_fixed=h_fixed)
        coords.append(xy)
    return centers, coords

def _dead_zone_center(col, i, sort_dir):
    """Return the (x, y) center of the dead zone between posts ``i`` and ``i+1`` of ``col``."""
    if sort_dir == 'cols':
        y0 = col[i][1]
        y1 = col[i+1][1]
        if abs(y1-y0) < 1:
            raise ValueError('y0 and y1 too close')
        #Halfway between the two posts in y; mean x-position of all posts in the column
        return np.mean(col[:,0]), y0+(y1-y0)/2
    if sort_dir == 'rows':
        x0 = col[i][0]
        x1 = col[i+1][0]
        if abs(x1-x0) < 1:
            raise ValueError('x0 and x1 too close')
        #Halfway between the two posts in x; mean y-position of all posts in the row
        return x0+(x1-x0)/2, np.mean(col[:,1])
    raise ValueError(f"Unknown value of 'sort_dir': {sort_dir!r} (expected 'cols' or 'rows')")


def _array_or_object_array(items):
    """Return ``np.array(items)``; fall back to a 1D object array if the shapes of the items differ."""
    try:
        return np.array(items)
    except ValueError:
        #Newer NumPy versions refuse to build ragged arrays implicitly
        out = np.empty(len(items), dtype=object)
        for idx, item in enumerate(items):
            out[idx] = item
        return out


def find_dead_zone_centers_and_coords_quad(img, mask_fluid, v, settings_general, list_centroids_sorted, w_dz = 3, h_dz = 2):
    """[Finds the dead zones between posts in a post array]

    One dead zone is placed halfway between every pair of consecutive posts of
    each column (or row, depending on ``settings_general['sort_dir']``).

    Args:
        img ([type]): [Not used by this function]
        mask_fluid ([2D array]): [Fluid mask that restricts the dead zone pixels]
        v ([video object]): [Passed on to visualize_regions; frame_to_display is used in the titles]
        settings_general ([dictionary]): [Reads 'sort_dir' ('cols' or 'rows'), 'show_imgs' and the optional 'show_img_zoom_window']
        list_centroids_sorted ([list of lists (columns) with coordinates in a numpy array]): [Posts centroids sorted column-wise]
        w_dz (int, optional): [Width of a dead zone]. Defaults to 3.
        h_dz (int, optional): [Height of a dead zone]. Defaults to 2.

    Raises:
        ValueError: [If the mask is not 2D, if two consecutive posts are less than 1 pixel apart or if 'sort_dir' is unknown]

    Returns:
        [list_dz_centers : [List of arrays of dead zone coordinate centers, one per column]
        list_dz_coords : [List of arrays of dead zone coordinates, one per column]
        ]
    """
    print('Finding dead zone centers and coordinates')
    if len(mask_fluid.shape) != 2:
        raise ValueError(f'The dimension of the mask should be 2D, not {len(mask_fluid.shape)}D')

    #Initialize lists
    list_dz_centers = []
    list_dz_coords = []

    #Iterate the columns
    for col in list_centroids_sorted:
        list_dz_in_col = []
        list_dz_coords_in_col = []
        n_dead_zones = len(col)-1 #Number of dead zones per column: One less than the total number of posts in a column

        #Iterate all dead zones in a column
        for i in range(n_dead_zones):
            #Define the dead zone center coordinates
            x_dead_zone_center, y_dead_zone_center = _dead_zone_center(col, i, settings_general['sort_dir'])

            #Append the dead zone coordinate to the list of dead zone centers for the column
            list_dz_in_col.append(np.array([x_dead_zone_center,y_dead_zone_center]))

            #Find the allowed coordinates of the dead zone based on the fluid mask and pre-defined dead zone width and height
            xy = find_coords_dz(mask_fluid, x_dead_zone_center, y_dead_zone_center, h_dz = h_dz, w_dz = w_dz)
            list_dz_coords_in_col.append(xy)

        list_dz_centers.append(np.array(list_dz_in_col))
        #The coordinate arrays can differ in shape (e.g. empty ones for dead zones at an edge)
        list_dz_coords.append(_array_or_object_array(list_dz_coords_in_col))

    if settings_general['show_imgs']:
        if 'show_img_zoom_window' in settings_general:
            show_img_zoom_window = settings_general['show_img_zoom_window']
        else:
            show_img_zoom_window = [[150,200],[200,250]]
        visualize_regions(mask_fluid, v, list_dz_coords, list_dz_centers, plot_post_centers=False, post_center_color='white', save_img = False, with_bg = True, title=f'Found dead zones (colored) on top of frame #{v.frame_to_display}')
        visualize_regions(mask_fluid, v, list_dz_coords, list_dz_centers, plot_post_centers=False, post_center_color='white', save_img = False, with_bg = True, xlim=show_img_zoom_window[0], ylim=show_img_zoom_window[1], title=f'Found dead Zones (colored) on top of frame #{v.frame_to_display} ', hollow_rect=True)

    return list_dz_centers, list_dz_coords

def calc_n_rows_and_cols(list_dz_centers, settings_device_type, settings_general):
    """Calculate the number of columns and rows of the dead zone array.

    The dead zone centers are grouped with ``sort_centroids_list``; the number
    of groups and the size of the largest group give the two dimensions.

    Args:
        list_dz_centers: Dead zone centers, passed on to ``sort_centroids_list``.
        settings_device_type ([dictionary]): 'device_type' must be 'H' when
            sorting along 'y'.
        settings_general ([dictionary]): Optional 'sort_axis' ('x' or 'y');
            'y' is used when the key is missing.

    Returns:
        [tuple]: (n_cols, n_rows)

    Raises:
        ValueError: If the combination of sort axis and device type is not
            supported, or (for 'y') if more than 50 rows or more than 90
            columns were found.
    """
    print('\t\tCalculating the number of rows and columns in the dz array')
    sort_axis = settings_general.get('sort_axis', 'y')

    #Only sorting along x, or along y for device type 'H', is supported
    supported = (sort_axis == 'x') or (sort_axis == 'y' and settings_device_type['device_type'] == 'H')
    if not supported:
        raise ValueError('This option is not eligible')

    list_dz_centers_sorted = sort_centroids_list(list_dz_centers, settings_general, settings_device_type)

    if sort_axis == 'x':  #Sorting of the array by columns
        #The number of columns is just the length of the list
        n_cols = len(list_dz_centers_sorted)

        #Find the maximum number of rows
        n_rows = np.max(np.array([len(col) for col in list_dz_centers_sorted]))

        return n_cols, n_rows
    elif sort_axis == 'y':
        n_rows = len(list_dz_centers_sorted) #sorted in list with rows, each position contains all dz within the same row.

        #Find the maximum number of columns
        n_cols1 = np.max(np.array([len(row) for row in list_dz_centers_sorted]))
        n_cols2 = n_cols1-1 #Only reported in the printout below

        if n_rows > 50:
            raise ValueError(f'Found too many rows ({n_rows}')
        if n_cols1 > 90:
            raise ValueError(f'Found too many columns ({n_cols1}')
        print(f'\t\t\tn_rows = {n_rows}, n_cols = {n_cols1} and {n_cols2}')
        return n_cols1, n_rows
